This is achieved by propagating the psp_dev pointer from the lower device to the upper devices in the device stack via a netdevice notifier. The lowest device owns the psp_dev pointer, while the upper devices merely borrow it. When an upper device is unlinked from its lower device, the borrowed pointer is cleared on the upper device; when the psp_dev is unregistered, the borrowed pointers are cleared on all upper devices as well. This assumes that psp_dev is set on the lowest device before any upper devices are stacked on top of it.
Signed-off-by: Kiran Kella <[email protected]>
Reviewed-by: Ajit Kumar Khaparde <[email protected]>
Reviewed-by: Akhilesh Samineni <[email protected]>
---
 net/psp/psp_main.c | 152 +++++++++++++++++++++++++++++++++++++++++++++-
 1 file changed, 151 insertions(+), 1 deletion(-)

diff --git a/net/psp/psp_main.c b/net/psp/psp_main.c
index a8534124f626..303a3af3bdf2 100644
--- a/net/psp/psp_main.c
+++ b/net/psp/psp_main.c
@@ -3,6 +3,7 @@
 #include <linux/bitfield.h>
 #include <linux/list.h>
 #include <linux/netdevice.h>
+#include <linux/rtnetlink.h>
 #include <linux/xarray.h>
 #include <net/net_namespace.h>
 #include <net/psp.h>
@@ -110,13 +111,48 @@ void psp_dev_free(struct psp_dev *psd)
 	kfree_rcu(psd, rcu);
 }
 
+/**
+ * psp_clear_upper_dev_psp_dev() - Clear borrowed psp_dev pointer on upper
+ *                                 device
+ * @upper_dev: Upper device that may have borrowed psp_dev pointer
+ * @priv: netdev_nested_priv containing the psp_dev being unregistered
+ *
+ * Callback for netdev_walk_all_upper_dev_rcu() to clear borrowed psp_dev
+ * pointers on upper devices when the underlying psp_dev is being unregistered.
+ * Runs with psd->lock held, which serializes it against
+ * psp_netdevice_event() propagating the same pointer to a new upper device.
+ *
+ * Return: 0 to continue walking, non-zero to stop.
+ */
+static int psp_clear_upper_dev_psp_dev(struct net_device *upper_dev,
+				       struct netdev_nested_priv *priv)
+{
+	struct psp_dev *psd = priv->data;
+
+	lockdep_assert_held(&psd->lock);
+
+	/* The pointer is only compared here, never dereferenced. */
+	if (rcu_access_pointer(upper_dev->psp_dev) == psd)
+		RCU_INIT_POINTER(upper_dev->psp_dev, NULL);
+
+	return 0;
+}
+
 /**
  * psp_dev_unregister() - unregister PSP device
  * @psd: PSP device structure
+ *
+ * Unregisters a PSP device and clears all borrowed psp_dev pointers on
+ * upper devices (e.g., VLAN subinterfaces) that reference this device.
+ * This prevents use-after-free if upper devices still have borrowed
+ * pointers when the psp_dev structure is freed.
  */
 void psp_dev_unregister(struct psp_dev *psd)
 {
 	struct psp_assoc *pas, *next;
+	struct netdev_nested_priv priv = {
+		.data = psd,
+	};
 
 	mutex_lock(&psp_devs_lock);
 	mutex_lock(&psd->lock);
@@ -137,6 +173,16 @@ void psp_dev_unregister(struct psp_dev *psd)
 
 	rcu_assign_pointer(psd->main_netdev->psp_dev, NULL);
 
+	/* Clear borrowed psp_dev pointers on all upper devices. This runs
+	 * under psd->lock and before psd->ops is cleared, so a concurrent
+	 * psp_netdevice_event() either sees this device as unregistered or
+	 * has already published the pointer where this walk will find it.
+	 */
+	rcu_read_lock();
+	netdev_walk_all_upper_dev_rcu(psd->main_netdev,
+				      psp_clear_upper_dev_psp_dev, &priv);
+	rcu_read_unlock();
+
 	psd->ops = NULL;
 	psd->drv_priv = NULL;
 
@@ -313,11 +359,115 @@ int psp_dev_rcv(struct sk_buff *skb, u16 dev_id, u8 generation, bool strip_icv)
 }
 EXPORT_SYMBOL(psp_dev_rcv);
 
+/**
+ * psp_netdevice_event() - Handle netdevice events for PSP device propagation
+ * @nb: notifier block
+ * @event: netdevice event
+ * @ptr: netdevice notifier info
+ *
+ * Propagates the psp_dev pointer from a lower device to an upper device
+ * when the upper device is linked on top of it (e.g., VLAN subinterfaces),
+ * and clears the borrowed pointer again when the upper device is unlinked.
+ * Master uppers (e.g., bond, team or bridge devices) are excluded, as they
+ * may have multiple lower devices.
+ *
+ * Called with RTNL held.
+ *
+ * Return: NOTIFY_DONE
+ */
+static int psp_netdevice_event(struct notifier_block *nb,
+			       unsigned long event, void *ptr)
+{
+	struct netdev_notifier_changeupper_info *info;
+	struct net_device *dev, *upper_dev;
+	struct psp_dev *psd;
+
+	if (event != NETDEV_CHANGEUPPER)
+		return NOTIFY_DONE;
+
+	ASSERT_RTNL();
+
+	info = ptr;
+	dev = netdev_notifier_info_to_dev(ptr);
+	upper_dev = info->upper_dev;
+
+	/* Master uppers (bond, team, bridge, VRF, ...) may aggregate several
+	 * lower devices, so they must not borrow any single psp_dev.
+	 */
+	if (info->master)
+		return NOTIFY_DONE;
+
+	if (!info->linking) {
+		/* Lower device is being unlinked from an upper device.
+		 * Clear the borrowed psp_dev pointer on the upper device,
+		 * but never a psp_dev that upper_dev owns itself.
+		 * Any devices stacked above upper_dev will get their own
+		 * NETDEV_CHANGEUPPER notifications as the stack unwinds.
+		 */
+		rcu_read_lock();
+		psd = rcu_dereference(upper_dev->psp_dev);
+		if (psd && psd->main_netdev != upper_dev)
+			RCU_INIT_POINTER(upper_dev->psp_dev, NULL);
+		rcu_read_unlock();
+		return NOTIFY_DONE;
+	}
+
+	/* Lower device is being linked to an upper device.
+	 * Propagate psp_dev from the immediate lower device to the
+	 * upper device. The immediate lower device would have already
+	 * got the psp_dev pointer set in a previous notification (or
+	 * owns it if it's the lowest device).
+	 * Upper devices just borrow the pointer.
+	 *
+	 * The lower device's pointer is not protected by RTNL (it is
+	 * cleared by psp_dev_unregister() under psd->lock), so take a
+	 * reference under RCU before touching the psp_dev at all.
+	 */
+	rcu_read_lock();
+	psd = rcu_dereference(dev->psp_dev);
+	if (psd && !psp_dev_tryget(psd))
+		psd = NULL;
+	rcu_read_unlock();
+	if (!psd)
+		return NOTIFY_DONE;
+
+	/* psp_dev_unregister() clears psd->ops under psd->lock after
+	 * walking and clearing the upper devices. Publishing only while
+	 * the device is still registered, under the same lock, ensures a
+	 * concurrent unregister cannot leave a stale borrowed pointer
+	 * behind. Never overwrite a pointer upper_dev already has.
+	 */
+	mutex_lock(&psd->lock);
+	if (psd->ops &&
+	    !rcu_access_pointer(upper_dev->psp_dev))
+		rcu_assign_pointer(upper_dev->psp_dev, psd);
+	mutex_unlock(&psd->lock);
+
+	psp_dev_put(psd);
+
+	return NOTIFY_DONE;
+}
+
+static struct notifier_block psp_netdevice_notifier = {
+	.notifier_call = psp_netdevice_event,
+};
+
 static int __init psp_init(void)
 {
+	int err;
+
 	mutex_init(&psp_devs_lock);
 
-	return genl_register_family(&psp_nl_family);
+	err = register_netdevice_notifier(&psp_netdevice_notifier);
+	if (err)
+		return err;
+
+	err = genl_register_family(&psp_nl_family);
+	if (err) {
+		unregister_netdevice_notifier(&psp_netdevice_notifier);
+		return err;
+	}
+	return 0;
 }
 
 subsys_initcall(psp_init);
-- 
2.45.4

