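Add __xfrm_state_lookup_exact(), an identity-match counterpart to
__xfrm_state_lookup(): the SA's mark value and mark mask must both equal
the ones supplied, instead of the SA's mask being applied to the lookup
mark.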

Wire it into every SPI-keyed control-plane path: DELSA/GETSA, UPDSA,
GETAE/NEWAE, EXPIRE, MIGRATE_STATE.

The duplicate detection in xfrm_state_add() keeps using the wildcard
__xfrm_state_locate(); it is unrelated to this change and is left as is.

Fixes: 3d6acfa7641f ("xfrm: SA lookups with mark")
Signed-off-by: Antony Antony <[email protected]>
---
 include/net/xfrm.h    |  3 ++
 net/xfrm/xfrm_state.c | 80 ++++++++++++++++++++++++++++++++++++++++++++-------
 net/xfrm/xfrm_user.c  | 29 ++++++++++---------
 3 files changed, 88 insertions(+), 24 deletions(-)

diff --git a/include/net/xfrm.h b/include/net/xfrm.h
index 519a0156a05c..f6ed590cb2ff 100644
--- a/include/net/xfrm.h
+++ b/include/net/xfrm.h
@@ -1746,6 +1746,9 @@ struct xfrm_state *xfrm_state_lookup_byaddr(struct net 
*net, u32 mark,
                                            const xfrm_address_t *saddr,
                                            u8 proto,
                                            unsigned short family);
+struct xfrm_state *xfrm_state_lookup_exact(struct net *net, const struct 
xfrm_mark *mark,
+                                          const xfrm_address_t *daddr, __be32 
spi,
+                                          u8 proto, unsigned short family);
 #ifdef CONFIG_XFRM_SUB_POLICY
 void xfrm_tmpl_sort(struct xfrm_tmpl **dst, struct xfrm_tmpl **src, int n,
                    unsigned short family);
diff --git a/net/xfrm/xfrm_state.c b/net/xfrm/xfrm_state.c
index c58cd024e3c6..df761ce1c290 100644
--- a/net/xfrm/xfrm_state.c
+++ b/net/xfrm/xfrm_state.c
@@ -1172,11 +1172,22 @@ static struct xfrm_state *__xfrm_state_lookup_all(const 
struct xfrm_hash_state_p
        return NULL;
 }
 
-static struct xfrm_state *__xfrm_state_lookup(const struct 
xfrm_hash_state_ptrs *state_ptrs,
-                                             u32 mark,
-                                             const xfrm_address_t *daddr,
-                                             __be32 spi, u8 proto,
-                                             unsigned short family)
+/* exact=false: data-plane wildcard match against x's mask (@mask unused).
+ * exact=true: control-plane identity match, mark and mask must both equal x's.
+ */
+static bool xfrm_state_mark_matches(const struct xfrm_state *x, u32 mark, u32 
mask, bool exact)
+{
+       if (exact)
+               return x->mark.v == mark && x->mark.m == mask;
+       return (mark & x->mark.m) == x->mark.v;
+}
+
+static struct xfrm_state *
+__xfrm_state_lookup(const struct xfrm_hash_state_ptrs *state_ptrs,
+                   u32 mark, u32 mask, bool exact,
+                   const xfrm_address_t *daddr,
+                   __be32 spi, u8 proto,
+                   unsigned short family)
 {
        unsigned int h = __xfrm_spi_hash(daddr, spi, proto, family, 
state_ptrs->hmask);
        struct xfrm_state *x;
@@ -1188,7 +1199,7 @@ static struct xfrm_state *__xfrm_state_lookup(const 
struct xfrm_hash_state_ptrs
                    !xfrm_addr_equal(&x->id.daddr, daddr, family))
                        continue;
 
-               if ((mark & x->mark.m) != x->mark.v)
+               if (!xfrm_state_mark_matches(x, mark, mask, exact))
                        continue;
                if (!xfrm_state_hold_rcu(x))
                        continue;
@@ -1198,6 +1209,17 @@ static struct xfrm_state *__xfrm_state_lookup(const 
struct xfrm_hash_state_ptrs
        return NULL;
 }
 
+static struct xfrm_state *
+__xfrm_state_lookup_exact(const struct xfrm_hash_state_ptrs *state_ptrs,
+                         const struct xfrm_mark *mark,
+                         const xfrm_address_t *daddr,
+                         __be32 spi, u8 proto,
+                         unsigned short family)
+{
+       return __xfrm_state_lookup(state_ptrs, mark->v, mark->m, true,
+                                  daddr, spi, proto, family);
+}
+
 struct xfrm_state *xfrm_input_state_lookup(struct net *net, u32 mark,
                                           const xfrm_address_t *daddr,
                                           __be32 spi, u8 proto,
@@ -1228,7 +1250,7 @@ struct xfrm_state *xfrm_input_state_lookup(struct net 
*net, u32 mark,
 
        xfrm_hash_ptrs_get(net, &state_ptrs);
 
-       x = __xfrm_state_lookup(&state_ptrs, mark, daddr, spi, proto, family);
+       x = __xfrm_state_lookup(&state_ptrs, mark, 0, false, daddr, spi, proto, 
family);
        if (x) {
                spin_lock(&net->xfrm.xfrm_state_lock);
                if (x->km.state != XFRM_STATE_VALID) {
@@ -1288,7 +1310,7 @@ __xfrm_state_locate(struct xfrm_state *x, int use_spi, 
int family)
        xfrm_hash_ptrs_get(net, &state_ptrs);
 
        if (use_spi)
-               return __xfrm_state_lookup(&state_ptrs, mark, &x->id.daddr,
+               return __xfrm_state_lookup(&state_ptrs, mark, 0, false, 
&x->id.daddr,
                                           x->id.spi, x->id.proto, family);
        else
                return __xfrm_state_lookup_byaddr(&state_ptrs, mark,
@@ -1297,6 +1319,27 @@ __xfrm_state_locate(struct xfrm_state *x, int use_spi, 
int family)
                                                  x->id.proto, family);
 }
 
+/* Used by xfrm_state_update() only; xfrm_state_add()'s dup check keeps
+ * using the wildcard __xfrm_state_locate() above.
+ */
+static inline struct xfrm_state *
+__xfrm_state_locate_exact(struct xfrm_state *x, int use_spi, int family)
+{
+       struct xfrm_hash_state_ptrs state_ptrs;
+       struct net *net = xs_net(x);
+
+       xfrm_hash_ptrs_get(net, &state_ptrs);
+
+       if (use_spi)
+               return __xfrm_state_lookup_exact(&state_ptrs, &x->mark, 
&x->id.daddr,
+                                                x->id.spi, x->id.proto, 
family);
+       else
+               return __xfrm_state_lookup_byaddr(&state_ptrs, x->mark.v & 
x->mark.m,
+                                                 &x->id.daddr,
+                                                 &x->props.saddr,
+                                                 x->id.proto, family);
+}
+
 static void xfrm_hash_grow_check(struct net *net, int have_hash_collision)
 {
        if (have_hash_collision &&
@@ -2229,7 +2272,7 @@ int xfrm_state_update(struct xfrm_state *x)
        to_put = NULL;
 
        spin_lock_bh(&net->xfrm.xfrm_state_lock);
-       x1 = __xfrm_state_locate(x, use_spi, x->props.family);
+       x1 = __xfrm_state_locate_exact(x, use_spi, x->props.family);
 
        err = -ESRCH;
        if (!x1)
@@ -2374,7 +2417,7 @@ xfrm_state_lookup(struct net *net, u32 mark, const 
xfrm_address_t *daddr, __be32
        rcu_read_lock();
        xfrm_hash_ptrs_get(net, &state_ptrs);
 
-       x = __xfrm_state_lookup(&state_ptrs, mark, daddr, spi, proto, family);
+       x = __xfrm_state_lookup(&state_ptrs, mark, 0, false, daddr, spi, proto, 
family);
        rcu_read_unlock();
        return x;
 }
@@ -2398,6 +2441,23 @@ xfrm_state_lookup_byaddr(struct net *net, u32 mark,
 }
 EXPORT_SYMBOL(xfrm_state_lookup_byaddr);
 
+struct xfrm_state *
+xfrm_state_lookup_exact(struct net *net, const struct xfrm_mark *mark,
+                       const xfrm_address_t *daddr, __be32 spi,
+                       u8 proto, unsigned short family)
+{
+       struct xfrm_hash_state_ptrs state_ptrs;
+       struct xfrm_state *x;
+
+       rcu_read_lock();
+       xfrm_hash_ptrs_get(net, &state_ptrs);
+
+       x = __xfrm_state_lookup_exact(&state_ptrs, mark, daddr, spi, proto, 
family);
+       rcu_read_unlock();
+       return x;
+}
+EXPORT_SYMBOL(xfrm_state_lookup_exact);
+
 struct xfrm_state *
 xfrm_find_acq(struct net *net, const struct xfrm_mark *mark, u8 mode, u32 
reqid,
              u32 if_id, u32 pcpu_num, u8 proto, const xfrm_address_t *daddr,
diff --git a/net/xfrm/xfrm_user.c b/net/xfrm/xfrm_user.c
index 6384795ee6b2..b56fca666b89 100644
--- a/net/xfrm/xfrm_user.c
+++ b/net/xfrm/xfrm_user.c
@@ -1089,11 +1089,12 @@ static struct xfrm_state *xfrm_user_state_lookup(struct 
net *net,
        struct xfrm_state *x = NULL;
        struct xfrm_mark m;
        int err;
-       u32 mark = xfrm_mark_get(attrs, &m);
+
+       xfrm_mark_get(attrs, &m);
 
        if (xfrm_id_proto_match(p->proto, IPSEC_PROTO_ANY)) {
                err = -ESRCH;
-               x = xfrm_state_lookup(net, mark, &p->daddr, p->spi, p->proto, 
p->family);
+               x = xfrm_state_lookup_exact(net, &m, &p->daddr, p->spi, 
p->proto, p->family);
        } else {
                xfrm_address_t *saddr = NULL;
 
@@ -1104,7 +1105,7 @@ static struct xfrm_state *xfrm_user_state_lookup(struct 
net *net,
                }
 
                err = -ESRCH;
-               x = xfrm_state_lookup_byaddr(net, mark,
+               x = xfrm_state_lookup_byaddr(net, m.v & m.m,
                                             &p->daddr, saddr,
                                             p->proto, p->family);
        }
@@ -2788,14 +2789,13 @@ static int xfrm_get_ae(struct sk_buff *skb, struct 
nlmsghdr *nlh,
        struct sk_buff *r_skb;
        int err;
        struct km_event c;
-       u32 mark;
        struct xfrm_mark m;
        struct xfrm_aevent_id *p = nlmsg_data(nlh);
        struct xfrm_usersa_id *id = &p->sa_id;
 
-       mark = xfrm_mark_get(attrs, &m);
+       xfrm_mark_get(attrs, &m);
 
-       x = xfrm_state_lookup(net, mark, &id->daddr, id->spi, id->proto, 
id->family);
+       x = xfrm_state_lookup_exact(net, &m, &id->daddr, id->spi, id->proto, 
id->family);
        if (x == NULL)
                return -ESRCH;
 
@@ -2836,7 +2836,6 @@ static int xfrm_new_ae(struct sk_buff *skb, struct 
nlmsghdr *nlh,
        struct xfrm_state *x;
        struct km_event c;
        int err = -EINVAL;
-       u32 mark = 0;
        struct xfrm_mark m;
        struct xfrm_aevent_id *p = nlmsg_data(nlh);
        struct nlattr *rp = attrs[XFRMA_REPLAY_VAL];
@@ -2856,9 +2855,10 @@ static int xfrm_new_ae(struct sk_buff *skb, struct 
nlmsghdr *nlh,
                return err;
        }
 
-       mark = xfrm_mark_get(attrs, &m);
+       xfrm_mark_get(attrs, &m);
 
-       x = xfrm_state_lookup(net, mark, &p->sa_id.daddr, p->sa_id.spi, 
p->sa_id.proto, p->sa_id.family);
+       x = xfrm_state_lookup_exact(net, &m, &p->sa_id.daddr, p->sa_id.spi,
+                                   p->sa_id.proto, p->sa_id.family);
        if (x == NULL)
                return -ESRCH;
 
@@ -2992,9 +2992,10 @@ static int xfrm_add_sa_expire(struct sk_buff *skb, 
struct nlmsghdr *nlh,
        struct xfrm_user_expire *ue = nlmsg_data(nlh);
        struct xfrm_usersa_info *p = &ue->state;
        struct xfrm_mark m;
-       u32 mark = xfrm_mark_get(attrs, &m);
 
-       x = xfrm_state_lookup(net, mark, &p->id.daddr, p->id.spi, p->id.proto, 
p->family);
+       xfrm_mark_get(attrs, &m);
+
+       x = xfrm_state_lookup_exact(net, &m, &p->id.daddr, p->id.spi, 
p->id.proto, p->family);
 
        err = -ENOENT;
        if (x == NULL)
@@ -3361,9 +3362,9 @@ static int xfrm_do_migrate_state(struct sk_buff *skb, 
struct nlmsghdr *nlh,
 
        copy_from_user_migrate_state(&m, um);
 
-       x = xfrm_state_lookup(net, m.old_mark.v & m.old_mark.m,
-                             &um->id.daddr, um->id.spi,
-                             um->id.proto, um->id.family);
+       x = xfrm_state_lookup_exact(net, &m.old_mark,
+                                   &um->id.daddr, um->id.spi,
+                                   um->id.proto, um->id.family);
        if (!x) {
                NL_SET_ERR_MSG(extack, "Can not find state");
                return -ESRCH;

-- 
2.47.3


Reply via email to