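Add __xfrm_state_lookup_exact(), an identity-matching variant of
__xfrm_state_lookup(). The existing lookup is a data-plane wildcard
match: it applies the SA's mask to the caller's mark. A control-plane
request can therefore resolve to a different SA with the same (daddr,
SPI, proto) whose mark/mask happens to accept the requested mark (for
example, an SA with mark 0/0 accepts any mark). The exact variant
requires the SA's mark value and mask to equal those supplied by the
caller.
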
Wire it into every SPI-keyed control-plane path: DELSA/GETSA, UPDSA,
GETAE/NEWAE, EXPIRE, MIGRATE_STATE.
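xfrm_state_add()'s duplicate detection keeps using the wildcard
__xfrm_state_locate(); it is unrelated to this fix and left unchanged.
Lookups by address (no SPI) also keep the wildcard match.

Note that the SPI-keyed requests above now find an SA only when they
carry its exact mark value and mask.
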
Fixes: 3d6acfa7641f ("xfrm: SA lookups with mark")
Signed-off-by: Antony Antony <[email protected]>
---
include/net/xfrm.h | 3 ++
net/xfrm/xfrm_state.c | 80 ++++++++++++++++++++++++++++++++++++++++++++-------
net/xfrm/xfrm_user.c | 29 ++++++++++---------
3 files changed, 88 insertions(+), 24 deletions(-)
diff --git a/include/net/xfrm.h b/include/net/xfrm.h
index 519a0156a05c..f6ed590cb2ff 100644
--- a/include/net/xfrm.h
+++ b/include/net/xfrm.h
@@ -1746,6 +1746,9 @@ struct xfrm_state *xfrm_state_lookup_byaddr(struct net
*net, u32 mark,
const xfrm_address_t *saddr,
u8 proto,
unsigned short family);
+struct xfrm_state *xfrm_state_lookup_exact(struct net *net, const struct xfrm_mark *mark,
+					   const xfrm_address_t *daddr, __be32 spi,
+					   u8 proto, unsigned short family);
#ifdef CONFIG_XFRM_SUB_POLICY
void xfrm_tmpl_sort(struct xfrm_tmpl **dst, struct xfrm_tmpl **src, int n,
unsigned short family);
diff --git a/net/xfrm/xfrm_state.c b/net/xfrm/xfrm_state.c
index c58cd024e3c6..df761ce1c290 100644
--- a/net/xfrm/xfrm_state.c
+++ b/net/xfrm/xfrm_state.c
@@ -1172,11 +1172,22 @@ static struct xfrm_state *__xfrm_state_lookup_all(const
struct xfrm_hash_state_p
return NULL;
}
-static struct xfrm_state *__xfrm_state_lookup(const struct
xfrm_hash_state_ptrs *state_ptrs,
- u32 mark,
- const xfrm_address_t *daddr,
- __be32 spi, u8 proto,
- unsigned short family)
+/* exact=false: data-plane wildcard; x->mark.m applied to @mark, @mask unused.
+ * exact=true: control-plane identity; x->mark.v/.m must equal @mark/@mask.
+ */
+static bool xfrm_state_mark_matches(const struct xfrm_state *x, u32 mark, u32 mask, bool exact)
+{
+	if (exact)
+		return x->mark.v == mark && x->mark.m == mask;
+	return (mark & x->mark.m) == x->mark.v;
+}
+
+static struct xfrm_state *
+__xfrm_state_lookup(const struct xfrm_hash_state_ptrs *state_ptrs,
+ u32 mark, u32 mask, bool exact,
+ const xfrm_address_t *daddr,
+ __be32 spi, u8 proto,
+ unsigned short family)
{
unsigned int h = __xfrm_spi_hash(daddr, spi, proto, family,
state_ptrs->hmask);
struct xfrm_state *x;
@@ -1188,7 +1199,7 @@ static struct xfrm_state *__xfrm_state_lookup(const
struct xfrm_hash_state_ptrs
!xfrm_addr_equal(&x->id.daddr, daddr, family))
continue;
- if ((mark & x->mark.m) != x->mark.v)
+ if (!xfrm_state_mark_matches(x, mark, mask, exact))
continue;
if (!xfrm_state_hold_rcu(x))
continue;
@@ -1198,6 +1209,17 @@ static struct xfrm_state *__xfrm_state_lookup(const
struct xfrm_hash_state_ptrs
return NULL;
}
+/* Control-plane lookup: @mark's value and mask must equal the SA's exactly. */
+static struct xfrm_state *
+__xfrm_state_lookup_exact(const struct xfrm_hash_state_ptrs *state_ptrs,
+			  const struct xfrm_mark *mark,
+			  const xfrm_address_t *daddr,
+			  __be32 spi, u8 proto, unsigned short family)
+{
+	return __xfrm_state_lookup(state_ptrs, mark->v, mark->m, true,
+				   daddr, spi, proto, family);
+}
+
struct xfrm_state *xfrm_input_state_lookup(struct net *net, u32 mark,
const xfrm_address_t *daddr,
__be32 spi, u8 proto,
@@ -1228,7 +1250,7 @@ struct xfrm_state *xfrm_input_state_lookup(struct net
*net, u32 mark,
xfrm_hash_ptrs_get(net, &state_ptrs);
- x = __xfrm_state_lookup(&state_ptrs, mark, daddr, spi, proto, family);
+ x = __xfrm_state_lookup(&state_ptrs, mark, 0, false, daddr, spi, proto,
family);
if (x) {
spin_lock(&net->xfrm.xfrm_state_lock);
if (x->km.state != XFRM_STATE_VALID) {
@@ -1288,7 +1310,7 @@ __xfrm_state_locate(struct xfrm_state *x, int use_spi,
int family)
xfrm_hash_ptrs_get(net, &state_ptrs);
if (use_spi)
- return __xfrm_state_lookup(&state_ptrs, mark, &x->id.daddr,
+ return __xfrm_state_lookup(&state_ptrs, mark, 0, false,
&x->id.daddr,
x->id.spi, x->id.proto, family);
else
return __xfrm_state_lookup_byaddr(&state_ptrs, mark,
@@ -1297,6 +1319,27 @@ __xfrm_state_locate(struct xfrm_state *x, int use_spi,
int family)
x->id.proto, family);
}
+/* Exact-mark variant of __xfrm_state_locate(), used by xfrm_state_update()
+ * only; xfrm_state_add()'s dup check keeps the wildcard variant above.
+ * Without an SPI this still uses the wildcard (v & m) by-address lookup.
+ */
+static inline struct xfrm_state *
+__xfrm_state_locate_exact(struct xfrm_state *x, int use_spi, int family)
+{
+	struct xfrm_hash_state_ptrs state_ptrs;
+	struct net *net = xs_net(x);
+
+	xfrm_hash_ptrs_get(net, &state_ptrs);
+
+	if (use_spi)
+		return __xfrm_state_lookup_exact(&state_ptrs, &x->mark,
+						 &x->id.daddr, x->id.spi,
+						 x->id.proto, family);
+	return __xfrm_state_lookup_byaddr(&state_ptrs, x->mark.v & x->mark.m,
+					  &x->id.daddr, &x->props.saddr,
+					  x->id.proto, family);
+}
+
static void xfrm_hash_grow_check(struct net *net, int have_hash_collision)
{
if (have_hash_collision &&
@@ -2229,7 +2272,7 @@ int xfrm_state_update(struct xfrm_state *x)
to_put = NULL;
spin_lock_bh(&net->xfrm.xfrm_state_lock);
- x1 = __xfrm_state_locate(x, use_spi, x->props.family);
+ x1 = __xfrm_state_locate_exact(x, use_spi, x->props.family);
err = -ESRCH;
if (!x1)
@@ -2374,7 +2417,7 @@ xfrm_state_lookup(struct net *net, u32 mark, const
xfrm_address_t *daddr, __be32
rcu_read_lock();
xfrm_hash_ptrs_get(net, &state_ptrs);
- x = __xfrm_state_lookup(&state_ptrs, mark, daddr, spi, proto, family);
+ x = __xfrm_state_lookup(&state_ptrs, mark, 0, false, daddr, spi, proto,
family);
rcu_read_unlock();
return x;
}
@@ -2398,6 +2441,23 @@ xfrm_state_lookup_byaddr(struct net *net, u32 mark,
}
EXPORT_SYMBOL(xfrm_state_lookup_byaddr);
+/* xfrm_state_lookup() variant requiring an exact mark value and mask match. */
+struct xfrm_state *
+xfrm_state_lookup_exact(struct net *net, const struct xfrm_mark *mark,
+			const xfrm_address_t *daddr, __be32 spi,
+			u8 proto, unsigned short family)
+{
+	struct xfrm_hash_state_ptrs state_ptrs;
+	struct xfrm_state *x;
+
+	rcu_read_lock();
+	xfrm_hash_ptrs_get(net, &state_ptrs);
+	x = __xfrm_state_lookup_exact(&state_ptrs, mark, daddr, spi, proto, family);
+	rcu_read_unlock();
+	return x;
+}
+EXPORT_SYMBOL(xfrm_state_lookup_exact);
+
struct xfrm_state *
xfrm_find_acq(struct net *net, const struct xfrm_mark *mark, u8 mode, u32
reqid,
u32 if_id, u32 pcpu_num, u8 proto, const xfrm_address_t *daddr,
diff --git a/net/xfrm/xfrm_user.c b/net/xfrm/xfrm_user.c
index 6384795ee6b2..b56fca666b89 100644
--- a/net/xfrm/xfrm_user.c
+++ b/net/xfrm/xfrm_user.c
@@ -1089,11 +1089,12 @@ static struct xfrm_state *xfrm_user_state_lookup(struct
net *net,
struct xfrm_state *x = NULL;
struct xfrm_mark m;
int err;
- u32 mark = xfrm_mark_get(attrs, &m);
+
+ xfrm_mark_get(attrs, &m);
if (xfrm_id_proto_match(p->proto, IPSEC_PROTO_ANY)) {
err = -ESRCH;
- x = xfrm_state_lookup(net, mark, &p->daddr, p->spi, p->proto,
p->family);
+ x = xfrm_state_lookup_exact(net, &m, &p->daddr, p->spi,
p->proto, p->family);
} else {
xfrm_address_t *saddr = NULL;
@@ -1104,7 +1105,7 @@ static struct xfrm_state *xfrm_user_state_lookup(struct
net *net,
}
err = -ESRCH;
- x = xfrm_state_lookup_byaddr(net, mark,
+ x = xfrm_state_lookup_byaddr(net, m.v & m.m,
&p->daddr, saddr,
p->proto, p->family);
}
@@ -2788,14 +2789,13 @@ static int xfrm_get_ae(struct sk_buff *skb, struct
nlmsghdr *nlh,
struct sk_buff *r_skb;
int err;
struct km_event c;
- u32 mark;
struct xfrm_mark m;
struct xfrm_aevent_id *p = nlmsg_data(nlh);
struct xfrm_usersa_id *id = &p->sa_id;
- mark = xfrm_mark_get(attrs, &m);
+ xfrm_mark_get(attrs, &m);
- x = xfrm_state_lookup(net, mark, &id->daddr, id->spi, id->proto,
id->family);
+ x = xfrm_state_lookup_exact(net, &m, &id->daddr, id->spi, id->proto,
id->family);
if (x == NULL)
return -ESRCH;
@@ -2836,7 +2836,6 @@ static int xfrm_new_ae(struct sk_buff *skb, struct
nlmsghdr *nlh,
struct xfrm_state *x;
struct km_event c;
int err = -EINVAL;
- u32 mark = 0;
struct xfrm_mark m;
struct xfrm_aevent_id *p = nlmsg_data(nlh);
struct nlattr *rp = attrs[XFRMA_REPLAY_VAL];
@@ -2856,9 +2855,10 @@ static int xfrm_new_ae(struct sk_buff *skb, struct
nlmsghdr *nlh,
return err;
}
- mark = xfrm_mark_get(attrs, &m);
+ xfrm_mark_get(attrs, &m);
- x = xfrm_state_lookup(net, mark, &p->sa_id.daddr, p->sa_id.spi,
p->sa_id.proto, p->sa_id.family);
+ x = xfrm_state_lookup_exact(net, &m, &p->sa_id.daddr, p->sa_id.spi,
+ p->sa_id.proto, p->sa_id.family);
if (x == NULL)
return -ESRCH;
@@ -2992,9 +2992,10 @@ static int xfrm_add_sa_expire(struct sk_buff *skb,
struct nlmsghdr *nlh,
struct xfrm_user_expire *ue = nlmsg_data(nlh);
struct xfrm_usersa_info *p = &ue->state;
struct xfrm_mark m;
- u32 mark = xfrm_mark_get(attrs, &m);
- x = xfrm_state_lookup(net, mark, &p->id.daddr, p->id.spi, p->id.proto,
p->family);
+ xfrm_mark_get(attrs, &m);
+
+ x = xfrm_state_lookup_exact(net, &m, &p->id.daddr, p->id.spi,
p->id.proto, p->family);
err = -ENOENT;
if (x == NULL)
@@ -3361,9 +3362,9 @@ static int xfrm_do_migrate_state(struct sk_buff *skb,
struct nlmsghdr *nlh,
copy_from_user_migrate_state(&m, um);
- x = xfrm_state_lookup(net, m.old_mark.v & m.old_mark.m,
- &um->id.daddr, um->id.spi,
- um->id.proto, um->id.family);
+ x = xfrm_state_lookup_exact(net, &m.old_mark,
+ &um->id.daddr, um->id.spi,
+ um->id.proto, um->id.family);
if (!x) {
NL_SET_ERR_MSG(extack, "Can not find state");
return -ESRCH;
--
2.47.3