Bonding now updates its header_ops callbacks at runtime, so lockless
readers can observe concurrent callback updates.

Load the header_ops callbacks with READ_ONCE() and call the loaded
function pointer, instead of re-reading it through dev->header_ops.
Re-reading would let a concurrent update change the callback between
its NULL check and the call.

Signed-off-by: Kota Toda <[email protected]>
Co-developed-by: Yuki Koike <[email protected]>
Signed-off-by: Yuki Koike <[email protected]>
---
 include/linux/netdevice.h | 41 +++++++++++++++++++++++++++------------
 include/net/cfg802154.h   |  2 +-
 net/core/neighbour.c      | 13 +++++++++----
 net/ipv4/arp.c            |  2 +-
 net/ipv6/ndisc.c          |  2 +-
 5 files changed, 41 insertions(+), 19 deletions(-)

diff --git a/include/linux/netdevice.h b/include/linux/netdevice.h
index 77a99c8ab..79fb0864a 100644
--- a/include/linux/netdevice.h
+++ b/include/linux/netdevice.h
@@ -3150,35 +3150,50 @@ static inline int dev_hard_header(struct sk_buff *skb, struct net_device *dev,
                                  const void *daddr, const void *saddr,
                                  unsigned int len)
 {
-       if (!dev->header_ops || !dev->header_ops->create)
-               return 0;
+       int (*create)(struct sk_buff *skb, struct net_device *dev,
+                     unsigned short type, const void *daddr,
+                     const void *saddr, unsigned int len);
 
-       return dev->header_ops->create(skb, dev, type, daddr, saddr, len);
+       if (!dev->header_ops)
+               return 0;
+       create = READ_ONCE(dev->header_ops->create);
+       if (!create)
+               return 0;
+       return create(skb, dev, type, daddr, saddr, len);
 }
 
 static inline int dev_parse_header(const struct sk_buff *skb,
                                   unsigned char *haddr)
 {
+       int (*parse)(const struct sk_buff *skb, unsigned char *haddr);
        const struct net_device *dev = skb->dev;
 
-       if (!dev->header_ops || !dev->header_ops->parse)
+       if (!dev->header_ops)
                return 0;
-       return dev->header_ops->parse(skb, haddr);
+       parse = READ_ONCE(dev->header_ops->parse);
+       if (!parse)
+               return 0;
+       return parse(skb, haddr);
 }
 
 static inline __be16 dev_parse_header_protocol(const struct sk_buff *skb)
 {
+       __be16 (*parse_protocol)(const struct sk_buff *skb);
        const struct net_device *dev = skb->dev;
 
-       if (!dev->header_ops || !dev->header_ops->parse_protocol)
+       if (!dev->header_ops)
+               return 0;
+       parse_protocol = READ_ONCE(dev->header_ops->parse_protocol);
+       if (!parse_protocol)
                return 0;
-       return dev->header_ops->parse_protocol(skb);
+       return parse_protocol(skb);
 }
 
 /* ll_header must have at least hard_header_len allocated */
 static inline bool dev_validate_header(const struct net_device *dev,
                                       char *ll_header, int len)
 {
+       bool (*validate)(const char *ll_header, unsigned int len);
        if (likely(len >= dev->hard_header_len))
                return true;
        if (len < dev->min_header_len)
@@ -3189,15 +3204,17 @@ static inline bool dev_validate_header(const struct net_device *dev,
                return true;
        }
 
-       if (dev->header_ops && dev->header_ops->validate)
-               return dev->header_ops->validate(ll_header, len);
-
-       return false;
+       if (!dev->header_ops)
+               return false;
+       validate = READ_ONCE(dev->header_ops->validate);
+       if (!validate)
+               return false;
+       return validate(ll_header, len);
 }
 
 static inline bool dev_has_header(const struct net_device *dev)
 {
-       return dev->header_ops && dev->header_ops->create;
+       return dev->header_ops && READ_ONCE(dev->header_ops->create);
 }
 
 /*
diff --git a/include/net/cfg802154.h b/include/net/cfg802154.h
index 76d2cd2e2..dec638763 100644
--- a/include/net/cfg802154.h
+++ b/include/net/cfg802154.h
@@ -522,7 +522,7 @@ wpan_dev_hard_header(struct sk_buff *skb, struct net_device *dev,
 {
        struct wpan_dev *wpan_dev = dev->ieee802154_ptr;
 
-       return wpan_dev->header_ops->create(skb, dev, daddr, saddr, len);
+       return READ_ONCE(wpan_dev->header_ops->create)(skb, dev, daddr, saddr, len);
 }
 #endif
 
diff --git a/net/core/neighbour.c b/net/core/neighbour.c
index 96786016d..ff948e35e 100644
--- a/net/core/neighbour.c
+++ b/net/core/neighbour.c
@@ -1270,7 +1270,7 @@ static void neigh_update_hhs(struct neighbour *neigh)
                = NULL;
 
        if (neigh->dev->header_ops)
-               update = neigh->dev->header_ops->cache_update;
+               update = READ_ONCE(neigh->dev->header_ops->cache_update);
 
        if (update) {
                hh = &neigh->hh;
@@ -1540,7 +1540,12 @@ static void neigh_hh_init(struct neighbour *n)
         * hh_cache entry.
         */
-       if (!hh->hh_len)
-               dev->header_ops->cache(n, hh, prot);
+       if (!hh->hh_len) {
+               typeof(dev->header_ops->cache) cache =
+                       READ_ONCE(dev->header_ops->cache);
+
+               if (cache)
+                       cache(n, hh, prot);
+       }
 
        write_unlock_bh(&n->lock);
 }
@@ -1556,7 +1561,7 @@ int neigh_resolve_output(struct neighbour *neigh, struct sk_buff *skb)
                struct net_device *dev = neigh->dev;
                unsigned int seq;
 
-               if (dev->header_ops->cache && !READ_ONCE(neigh->hh.hh_len))
+               if (READ_ONCE(dev->header_ops->cache) && !READ_ONCE(neigh->hh.hh_len))
                        neigh_hh_init(neigh);
 
                do {
diff --git a/net/ipv4/arp.c b/net/ipv4/arp.c
index 7822b2144..421bea6eb 100644
--- a/net/ipv4/arp.c
+++ b/net/ipv4/arp.c
@@ -278,7 +278,7 @@ static int arp_constructor(struct neighbour *neigh)
                        memcpy(neigh->ha, dev->broadcast, dev->addr_len);
                }
 
-               if (dev->header_ops->cache)
+               if (READ_ONCE(dev->header_ops->cache))
                        neigh->ops = &arp_hh_ops;
                else
                        neigh->ops = &arp_generic_ops;
diff --git a/net/ipv6/ndisc.c b/net/ipv6/ndisc.c
index d961e6c2d..d81f509ec 100644
--- a/net/ipv6/ndisc.c
+++ b/net/ipv6/ndisc.c
@@ -361,7 +361,7 @@ static int ndisc_constructor(struct neighbour *neigh)
                        neigh->nud_state = NUD_NOARP;
                        memcpy(neigh->ha, dev->broadcast, dev->addr_len);
                }
-               if (dev->header_ops->cache)
+               if (READ_ONCE(dev->header_ops->cache))
                        neigh->ops = &ndisc_hh_ops;
                else
                        neigh->ops = &ndisc_generic_ops;
-- 
2.53.0


Reply via email to