This is achieved by propagating the psp_dev from the lower device
to the upper devices in the device stack via a netdevice notifier.
The lowest device owns the psp_dev pointer, while the upper devices
only borrow it. When an upper device is unlinked from its lower
device, the borrowed pointer is cleared on the upper device, and
psp_dev_unregister() clears any pointers still borrowed from the
psp_dev being removed.
This assumes that psp_dev is set on the lowest device before any
upper devices are stacked on top of it.

Signed-off-by: Kiran Kella <[email protected]>
Reviewed-by: Ajit Kumar Khaparde <[email protected]>
Reviewed-by: Akhilesh Samineni <[email protected]>
---
 net/psp/psp_main.c | 117 ++++++++++++++++++++++++++++++++++++++++++++-
 1 file changed, 116 insertions(+), 1 deletion(-)

diff --git a/net/psp/psp_main.c b/net/psp/psp_main.c
index a8534124f626..303a3af3bdf2 100644
--- a/net/psp/psp_main.c
+++ b/net/psp/psp_main.c
@@ -3,6 +3,7 @@
 #include <linux/bitfield.h>
 #include <linux/list.h>
 #include <linux/netdevice.h>
+#include <linux/rtnetlink.h>
 #include <linux/xarray.h>
 #include <net/net_namespace.h>
 #include <net/psp.h>
@@ -110,13 +111,45 @@ void psp_dev_free(struct psp_dev *psd)
        kfree_rcu(psd, rcu);
 }
 
+/**
+ * psp_clear_upper_dev_psp_dev() - Clear borrowed psp_dev pointer on upper
+ *                                device
+ * @upper_dev: Upper device that may have borrowed psp_dev pointer
+ * @priv:      netdev_nested_priv containing the psp_dev being unregistered
+ *
+ * Callback for netdev_walk_all_upper_dev_rcu() to clear borrowed psp_dev
+ * pointers on upper devices when the underlying psp_dev is being unregistered.
+ * Upper devices pointing at any other psp_dev are left untouched.
+ *
+ * Return: always 0, so the walk visits every upper device in the stack.
+ */
+static int psp_clear_upper_dev_psp_dev(struct net_device *upper_dev,
+                                      struct netdev_nested_priv *priv)
+{
+       struct psp_dev *psd = priv->data;
+
+       /* Value comparison only, so no rcu_dereference() is needed. */
+       if (rcu_access_pointer(upper_dev->psp_dev) == psd)
+               rcu_assign_pointer(upper_dev->psp_dev, NULL);
+
+       return 0;
+}
+
 /**
  * psp_dev_unregister() - unregister PSP device
  * @psd:       PSP device structure
+ *
+ * Unregisters a PSP device and clears all borrowed psp_dev pointers on
+ * upper devices (e.g., VLAN subinterfaces) that reference this device.
+ * This prevents use-after-free if upper devices still have borrowed
+ * pointers when the psp_dev structure is freed.
  */
 void psp_dev_unregister(struct psp_dev *psd)
 {
        struct psp_assoc *pas, *next;
+       struct netdev_nested_priv priv = {
+               .data = psd,
+       };
 
        mutex_lock(&psp_devs_lock);
        mutex_lock(&psd->lock);
@@ -137,6 +170,12 @@ void psp_dev_unregister(struct psp_dev *psd)
 
        rcu_assign_pointer(psd->main_netdev->psp_dev, NULL);
 
+       /* Clear borrowed psp_dev pointers on all upper devices */
+       rcu_read_lock();
+       netdev_walk_all_upper_dev_rcu(psd->main_netdev,
+                                     psp_clear_upper_dev_psp_dev, &priv);
+       rcu_read_unlock();
+
        psd->ops = NULL;
        psd->drv_priv = NULL;
 
@@ -313,11 +352,87 @@ int psp_dev_rcv(struct sk_buff *skb, u16 dev_id, u8 generation, bool strip_icv)
 }
 EXPORT_SYMBOL(psp_dev_rcv);
 
+/**
+ * psp_netdevice_event() - Handle netdevice events for PSP device propagation
+ * @nb:                notifier block
+ * @event:     netdevice event
+ * @ptr:       netdevice notifier info
+ *
+ * Propagates psp_dev pointer from lower devices to upper devices when
+ * upper devices are linked (e.g., VLAN subinterfaces) and drops the
+ * borrowed pointer when they are unlinked. Upper devices that can have
+ * multiple lower devices (LAG masters such as bonds) are skipped.
+ *
+ * Return: NOTIFY_DONE
+ */
+static int psp_netdevice_event(struct notifier_block *nb,
+                              unsigned long event, void *ptr)
+{
+       struct netdev_notifier_changeupper_info *info;
+       struct net_device *dev, *upper_dev;
+       struct psp_dev *psd;
+
+       if (event != NETDEV_CHANGEUPPER)
+               return NOTIFY_DONE;
+
+       info = ptr;
+       dev = netdev_notifier_info_to_dev(ptr);
+       upper_dev = info->upper_dev;
+
+       if (netif_is_lag_master(upper_dev))
+               return NOTIFY_DONE;
+
+       if (info->linking) {
+               /* Borrow the lower device's psp_dev unless upper_dev has
+                * its own. Recheck registration under psd->lock, which
+                * psp_dev_unregister() holds while it clears ->ops and the
+                * upper devices' pointers, so a pointer published here is
+                * always seen (and cleared) by that walk.
+                */
+               psd = rtnl_dereference(dev->psp_dev);
+               if (!psd || rcu_access_pointer(upper_dev->psp_dev) ||
+                   !psp_dev_tryget(psd))
+                       return NOTIFY_DONE;
+               mutex_lock(&psd->lock);
+               if (psd->ops)
+                       rcu_assign_pointer(upper_dev->psp_dev, psd);
+               mutex_unlock(&psd->lock);
+               psp_dev_put(psd);
+       } else {
+               /* Drop the borrowed pointer, but never clear a psp_dev that
+                * upper_dev owns. Devices stacked above upper_dev get their
+                * own NETDEV_CHANGEUPPER notifications when unlinked.
+                */
+               rcu_read_lock();
+               psd = rcu_dereference(upper_dev->psp_dev);
+               if (psd && psd->main_netdev != upper_dev)
+                       rcu_assign_pointer(upper_dev->psp_dev, NULL);
+               rcu_read_unlock();
+       }
+
+       return NOTIFY_DONE;
+}
+
+static struct notifier_block psp_netdevice_notifier = {
+       .notifier_call  = psp_netdevice_event,
+};
+
 static int __init psp_init(void)
 {
+       int err;
+
        mutex_init(&psp_devs_lock);
 
-       return genl_register_family(&psp_nl_family);
+       err = register_netdevice_notifier(&psp_netdevice_notifier);
+       if (err)
+               return err;
+
+       err = genl_register_family(&psp_nl_family);
+       /* Roll back the notifier if the netlink family cannot register. */
+       if (err)
+               unregister_netdevice_notifier(&psp_netdevice_notifier);
+
+       return err;
 }
 
 subsys_initcall(psp_init);
-- 
2.45.4


Reply via email to