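Use READ_ONCE() when loading header_ops callback pointers that may be updated concurrently, so the compiler cannot tear, fuse or re-load those accesses. Where a callback is NULL-checked and then called, load it once into a local variable so that the value that was checked is the value that is called.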
Signed-off-by: Kota Toda <[email protected]>
Co-developed-by: Yuki Koike <[email protected]>
Signed-off-by: Yuki Koike <[email protected]>
---
 include/linux/netdevice.h | 40 +++++++++++++++++++++++++++++++---------
 include/net/cfg802154.h   |  2 +-
 net/core/neighbour.c      |  6 +++---
 net/ipv4/arp.c            |  2 +-
 net/ipv6/ndisc.c          |  2 +-
 5 files changed, 37 insertions(+), 15 deletions(-)

diff --git a/include/linux/netdevice.h b/include/linux/netdevice.h
index 77a99c8ab..f50b0a4e8 100644
--- a/include/linux/netdevice.h
+++ b/include/linux/netdevice.h
@@ -3150,35 +3150,53 @@ static inline int dev_hard_header(struct sk_buff *skb, struct net_device *dev,
 				  const void *daddr, const void *saddr,
 				  unsigned int len)
 {
-	if (!dev->header_ops || !dev->header_ops->create)
+	const struct header_ops *ops = READ_ONCE(dev->header_ops);
+	int (*create)(struct sk_buff *skb, struct net_device *dev,
+		      unsigned short type, const void *daddr,
+		      const void *saddr, unsigned int len);
+
+	/* Load dev->header_ops once and NULL-check it before touching its
+	 * members, so the callback that is tested is the one that is called.
+	 */
+	create = ops ? READ_ONCE(ops->create) : NULL;
+	if (!create)
 		return 0;
 
-	return dev->header_ops->create(skb, dev, type, daddr, saddr, len);
+	return create(skb, dev, type, daddr, saddr, len);
 }
 
 static inline int dev_parse_header(const struct sk_buff *skb,
 				   unsigned char *haddr)
 {
 	const struct net_device *dev = skb->dev;
+	const struct header_ops *ops = READ_ONCE(dev->header_ops);
+	int (*parse)(const struct sk_buff *skb, unsigned char *haddr);
 
-	if (!dev->header_ops || !dev->header_ops->parse)
+	parse = ops ? READ_ONCE(ops->parse) : NULL;
+	if (!parse)
 		return 0;
-	return dev->header_ops->parse(skb, haddr);
+	return parse(skb, haddr);
 }
 
 static inline __be16 dev_parse_header_protocol(const struct sk_buff *skb)
 {
 	const struct net_device *dev = skb->dev;
+	const struct header_ops *ops = READ_ONCE(dev->header_ops);
+	__be16 (*parse_protocol)(const struct sk_buff *skb);
 
-	if (!dev->header_ops || !dev->header_ops->parse_protocol)
+	parse_protocol = ops ? READ_ONCE(ops->parse_protocol) : NULL;
+	if (!parse_protocol)
 		return 0;
-	return dev->header_ops->parse_protocol(skb);
+	return parse_protocol(skb);
 }
 
 /* ll_header must have at least hard_header_len allocated */
 static inline bool dev_validate_header(const struct net_device *dev,
 				       char *ll_header, int len)
 {
+	const struct header_ops *ops;
+	bool (*validate)(const char *ll_header, unsigned int len);
+
 	if (likely(len >= dev->hard_header_len))
 		return true;
 	if (len < dev->min_header_len)
@@ -3189,15 +3207,19 @@ static inline bool dev_validate_header(const struct net_device *dev,
 		return true;
 	}
 
-	if (dev->header_ops && dev->header_ops->validate)
-		return dev->header_ops->validate(ll_header, len);
+	ops = READ_ONCE(dev->header_ops);
+	validate = ops ? READ_ONCE(ops->validate) : NULL;
+	if (validate)
+		return validate(ll_header, len);
 
 	return false;
 }
 
 static inline bool dev_has_header(const struct net_device *dev)
 {
-	return dev->header_ops && dev->header_ops->create;
+	const struct header_ops *ops = READ_ONCE(dev->header_ops);
+
+	return ops && READ_ONCE(ops->create);
 }
 
 /*
diff --git a/include/net/cfg802154.h b/include/net/cfg802154.h
index 76d2cd2e2..dec638763 100644
--- a/include/net/cfg802154.h
+++ b/include/net/cfg802154.h
@@ -522,7 +522,7 @@ wpan_dev_hard_header(struct sk_buff *skb, struct net_device *dev,
 {
 	struct wpan_dev *wpan_dev = dev->ieee802154_ptr;
 
-	return wpan_dev->header_ops->create(skb, dev, daddr, saddr, len);
+	return READ_ONCE(wpan_dev->header_ops->create)(skb, dev, daddr, saddr, len);
 }
 
 #endif
diff --git a/net/core/neighbour.c b/net/core/neighbour.c
index 96786016d..ff948e35e 100644
--- a/net/core/neighbour.c
+++ b/net/core/neighbour.c
@@ -1270,7 +1270,7 @@ static void neigh_update_hhs(struct neighbour *neigh)
 		= NULL;
 
 	if (neigh->dev->header_ops)
-		update = neigh->dev->header_ops->cache_update;
+		update = READ_ONCE(neigh->dev->header_ops->cache_update);
 
 	if (update) {
 		hh = &neigh->hh;
@@ -1540,7 +1540,7 @@ static void neigh_hh_init(struct neighbour *n)
 	 * hh_cache entry.
 	 */
 	if (!hh->hh_len)
-		dev->header_ops->cache(n, hh, prot);
+		READ_ONCE(dev->header_ops->cache)(n, hh, prot);
 
 	write_unlock_bh(&n->lock);
 }
@@ -1556,7 +1556,7 @@ int neigh_resolve_output(struct neighbour *neigh, struct sk_buff *skb)
 		struct net_device *dev = neigh->dev;
 		unsigned int seq;
 
-		if (dev->header_ops->cache && !READ_ONCE(neigh->hh.hh_len))
+		if (READ_ONCE(dev->header_ops->cache) && !READ_ONCE(neigh->hh.hh_len))
 			neigh_hh_init(neigh);
 
 		do {
diff --git a/net/ipv4/arp.c b/net/ipv4/arp.c
index 7822b2144..421bea6eb 100644
--- a/net/ipv4/arp.c
+++ b/net/ipv4/arp.c
@@ -278,7 +278,7 @@ static int arp_constructor(struct neighbour *neigh)
 			memcpy(neigh->ha, dev->broadcast, dev->addr_len);
 		}
 
-		if (dev->header_ops->cache)
+		if (READ_ONCE(dev->header_ops->cache))
 			neigh->ops = &arp_hh_ops;
 		else
 			neigh->ops = &arp_generic_ops;
diff --git a/net/ipv6/ndisc.c b/net/ipv6/ndisc.c
index d961e6c2d..d81f509ec 100644
--- a/net/ipv6/ndisc.c
+++ b/net/ipv6/ndisc.c
@@ -361,7 +361,7 @@ static int ndisc_constructor(struct neighbour *neigh)
 			neigh->nud_state = NUD_NOARP;
 			memcpy(neigh->ha, dev->broadcast, dev->addr_len);
 		}
-		if (dev->header_ops->cache)
+		if (READ_ONCE(dev->header_ops->cache))
 			neigh->ops = &ndisc_hh_ops;
 		else
 			neigh->ops = &ndisc_generic_ops;
-- 
2.53.0

