Fix several race conditions in ipoib_multicast.c:
1. Make sure mcast->query is set to NULL if, and only if,
   no query is outstanding.
2. Make sure mcast->done is initialized to an uncompleted state
   before we submit a new query, so that it's safe to wait on.
3. Protect all accesses to priv->broadcast, priv->multicast_list,
   mcast->query and mcast->done by priv->lock.
   I had to change the mcast_mutex semaphore into the ipoib_mcast_lock
   spinlock to make the last bit work.

Signed-off-by: Michael S. Tsirkin <[EMAIL PROTECTED]>

Index: linux-2.6.14-dbg/drivers/infiniband/ulp/ipoib/ipoib_main.c
===================================================================
--- linux-2.6.14-dbg.orig/drivers/infiniband/ulp/ipoib/ipoib_main.c     
2005-11-20 11:57:00.000000000 +0200
+++ linux-2.6.14-dbg/drivers/infiniband/ulp/ipoib/ipoib_main.c  2005-11-20 
14:57:18.000000000 +0200
@@ -1146,6 +1146,8 @@ static int __init ipoib_init_module(void
        if (ret)
                goto err_wq;
 
+       spin_lock_init(&ipoib_mcast_lock);
+
        return 0;
 
 err_wq:
Index: linux-2.6.14-dbg/drivers/infiniband/ulp/ipoib/ipoib_multicast.c
===================================================================
--- linux-2.6.14-dbg.orig/drivers/infiniband/ulp/ipoib/ipoib_multicast.c        
2005-11-20 12:34:04.000000000 +0200
+++ linux-2.6.14-dbg/drivers/infiniband/ulp/ipoib/ipoib_multicast.c     
2005-11-20 14:57:18.000000000 +0200
@@ -53,7 +53,7 @@ MODULE_PARM_DESC(mcast_debug_level,
                 "Enable multicast debug tracing if > 0");
 #endif
 
-static DECLARE_MUTEX(mcast_mutex);
+spinlock_t ipoib_mcast_lock;
 
 /* Used for all multicast joins (broadcast, IPv4 mcast and IPv6 mcast) */
 struct ipoib_mcast {
@@ -126,17 +126,14 @@ static void ipoib_mcast_free(struct ipoi
        kfree(mcast);
 }
 
-static struct ipoib_mcast *ipoib_mcast_alloc(struct net_device *dev,
-                                            int can_sleep)
+static struct ipoib_mcast *ipoib_mcast_alloc(struct net_device *dev)
 {
        struct ipoib_mcast *mcast;
 
-       mcast = kzalloc(sizeof *mcast, can_sleep ? GFP_KERNEL : GFP_ATOMIC);
+       mcast = kzalloc(sizeof *mcast, GFP_ATOMIC);
        if (!mcast)
                return NULL;
 
-       init_completion(&mcast->done);
-
        mcast->dev = dev;
        mcast->created = jiffies;
        mcast->backoff = 1;
@@ -209,17 +206,23 @@ static int ipoib_mcast_join_finish(struc
 {
        struct net_device *dev = mcast->dev;
        struct ipoib_dev_priv *priv = netdev_priv(dev);
+       unsigned long flags;
        int ret;
 
        mcast->mcmember = *mcmember;
 
+       spin_lock_irqsave(&priv->lock, flags);
+
        /* Set the cached Q_Key before we attach if it's the broadcast group */
-       if (!memcmp(mcast->mcmember.mgid.raw, priv->dev->broadcast + 4,
+       if (priv->broadcast &&
+           !memcmp(mcast->mcmember.mgid.raw, priv->dev->broadcast + 4,
                    sizeof (union ib_gid))) {
                priv->qkey = be32_to_cpu(priv->broadcast->mcmember.qkey);
                priv->tx_wr.wr.ud.remote_qkey = priv->qkey;
        }
 
+       spin_unlock_irqrestore(&priv->lock, flags);
+
        if (!test_bit(IPOIB_MCAST_FLAG_SENDONLY, &mcast->flags)) {
                if (test_and_set_bit(IPOIB_MCAST_FLAG_ATTACHED, &mcast->flags)) 
{
                        ipoib_warn(priv, "multicast group " IPOIB_GID_FMT
@@ -303,6 +306,12 @@ ipoib_mcast_sendonly_join_complete(int s
 {
        struct ipoib_mcast *mcast = mcast_ptr;
        struct net_device *dev = mcast->dev;
+       struct ipoib_dev_priv *priv = netdev_priv(dev);
+       unsigned long flags;
+
+       ipoib_dbg_mcast(priv, "sendonly join completion for " IPOIB_GID_FMT
+                       " (status %d)\n",
+                       IPOIB_GID_ARG(mcast->mcmember.mgid), status);
 
        if (!status)
                ipoib_mcast_join_finish(mcast, mcmember);
@@ -320,7 +329,11 @@ ipoib_mcast_sendonly_join_complete(int s
                clear_bit(IPOIB_MCAST_FLAG_BUSY, &mcast->flags);
        }
 
+       spin_lock_irqsave(&priv->lock, flags);
+       mcast->query = NULL;
+
        complete(&mcast->done);
+       spin_unlock_irqrestore(&priv->lock, flags);
 }
 
 static int ipoib_mcast_sendonly_join(struct ipoib_mcast *mcast)
@@ -350,6 +363,7 @@ static int ipoib_mcast_sendonly_join(str
        rec.port_gid = priv->local_gid;
        rec.pkey     = cpu_to_be16(priv->pkey);
 
+       init_completion(&mcast->done);
        ret = ib_sa_mcmember_rec_set(priv->ca, priv->port, &rec,
                                     IB_SA_MCMEMBER_REC_MGID            |
                                     IB_SA_MCMEMBER_REC_PORT_GID        |
@@ -379,23 +393,31 @@ static void ipoib_mcast_join_complete(in
        struct ipoib_mcast *mcast = mcast_ptr;
        struct net_device *dev = mcast->dev;
        struct ipoib_dev_priv *priv = netdev_priv(dev);
+       unsigned long flags;
 
        ipoib_dbg_mcast(priv, "join completion for " IPOIB_GID_FMT
                        " (status %d)\n",
                        IPOIB_GID_ARG(mcast->mcmember.mgid), status);
 
+
        if (!status && !ipoib_mcast_join_finish(mcast, mcmember)) {
                mcast->backoff = 1;
-               down(&mcast_mutex);
+               spin_lock(&ipoib_mcast_lock);
                if (test_bit(IPOIB_MCAST_RUN, &priv->flags))
                        queue_work(ipoib_workqueue, &priv->mcast_task);
-               up(&mcast_mutex);
+               spin_unlock(&ipoib_mcast_lock);
+               spin_lock_irqsave(&priv->lock, flags);
+               mcast->query = NULL;
                complete(&mcast->done);
+               spin_unlock_irqrestore(&priv->lock, flags);
                return;
        }
 
        if (status == -EINTR) {
+               spin_lock_irqsave(&priv->lock, flags);
+               mcast->query = NULL;
                complete(&mcast->done);
+               spin_unlock_irqrestore(&priv->lock, flags);
                return;
        }
 
@@ -417,20 +439,21 @@ static void ipoib_mcast_join_complete(in
        if (mcast->backoff > IPOIB_MAX_BACKOFF_SECONDS)
                mcast->backoff = IPOIB_MAX_BACKOFF_SECONDS;
 
-       mcast->query = NULL;
+       spin_lock_irqsave(&priv->lock, flags);
 
-       down(&mcast_mutex);
+       spin_lock(&ipoib_mcast_lock);
        if (test_bit(IPOIB_MCAST_RUN, &priv->flags)) {
                if (status == -ETIMEDOUT)
                        queue_work(ipoib_workqueue, &priv->mcast_task);
                else
                        queue_delayed_work(ipoib_workqueue, &priv->mcast_task,
                                           mcast->backoff * HZ);
-       } else
-               complete(&mcast->done);
-       up(&mcast_mutex);
+       }
+       spin_unlock(&ipoib_mcast_lock);
 
-       return;
+       mcast->query = NULL;
+       complete(&mcast->done);
+       spin_unlock_irqrestore(&priv->lock, flags);
 }
 
 static void ipoib_mcast_join(struct net_device *dev, struct ipoib_mcast *mcast,
@@ -469,6 +492,7 @@ static void ipoib_mcast_join(struct net_
                rec.traffic_class = priv->broadcast->mcmember.traffic_class;
        }
 
+       init_completion(&mcast->done);
        ret = ib_sa_mcmember_rec_set(priv->ca, priv->port, &rec, comp_mask,
                                     mcast->backoff * 1000, GFP_ATOMIC,
                                     ipoib_mcast_join_complete,
@@ -481,12 +505,12 @@ static void ipoib_mcast_join(struct net_
                if (mcast->backoff > IPOIB_MAX_BACKOFF_SECONDS)
                        mcast->backoff = IPOIB_MAX_BACKOFF_SECONDS;
 
-               down(&mcast_mutex);
+               spin_lock(&ipoib_mcast_lock);
                if (test_bit(IPOIB_MCAST_RUN, &priv->flags))
                        queue_delayed_work(ipoib_workqueue,
                                           &priv->mcast_task,
                                           mcast->backoff * HZ);
-               up(&mcast_mutex);
+               spin_unlock(&ipoib_mcast_lock);
        } else
                mcast->query_id = ret;
 }
@@ -515,44 +539,44 @@ void ipoib_mcast_join_task(void *dev_ptr
                        ipoib_warn(priv, "ib_query_port failed\n");
        }
 
+       spin_lock_irq(&priv->lock);
+
        if (!priv->broadcast) {
-               priv->broadcast = ipoib_mcast_alloc(dev, 1);
+               priv->broadcast = ipoib_mcast_alloc(dev);
                if (!priv->broadcast) {
                        ipoib_warn(priv, "failed to allocate broadcast 
group\n");
-                       down(&mcast_mutex);
+                       spin_lock(&ipoib_mcast_lock);
                        if (test_bit(IPOIB_MCAST_RUN, &priv->flags))
                                queue_delayed_work(ipoib_workqueue,
                                                   &priv->mcast_task, HZ);
-                       up(&mcast_mutex);
-                       return;
+                       spin_unlock(&ipoib_mcast_lock);
+                       goto unlock;
                }
 
                memcpy(priv->broadcast->mcmember.mgid.raw, priv->dev->broadcast 
+ 4,
                       sizeof (union ib_gid));
 
-               spin_lock_irq(&priv->lock);
                __ipoib_mcast_add(dev, priv->broadcast);
-               spin_unlock_irq(&priv->lock);
        }
 
-       if (!test_bit(IPOIB_MCAST_FLAG_ATTACHED, &priv->broadcast->flags)) {
+       if (!test_bit(IPOIB_MCAST_FLAG_ATTACHED, &priv->broadcast->flags) &&
+           !priv->broadcast->query) {
                ipoib_mcast_join(dev, priv->broadcast, 0);
-               return;
+               goto unlock;
        }
 
        while (1) {
                struct ipoib_mcast *mcast = NULL;
 
-               spin_lock_irq(&priv->lock);
                list_for_each_entry(mcast, &priv->multicast_list, list) {
                        if (!test_bit(IPOIB_MCAST_FLAG_SENDONLY, &mcast->flags)
                            && !test_bit(IPOIB_MCAST_FLAG_BUSY, &mcast->flags)
-                           && !test_bit(IPOIB_MCAST_FLAG_ATTACHED, 
&mcast->flags)) {
+                           && !test_bit(IPOIB_MCAST_FLAG_ATTACHED, 
&mcast->flags)
+                           && !mcast->query) {
                                /* Found the next unjoined group */
                                break;
                        }
                }
-               spin_unlock_irq(&priv->lock);
 
                if (&mcast->list == &priv->multicast_list) {
                        /* All done */
@@ -560,7 +584,7 @@ void ipoib_mcast_join_task(void *dev_ptr
                }
 
                ipoib_mcast_join(dev, mcast, 1);
-               return;
+               goto unlock;
        }
 
        priv->mcast_mtu = ib_mtu_enum_to_int(priv->broadcast->mcmember.mtu) -
@@ -571,48 +595,59 @@ void ipoib_mcast_join_task(void *dev_ptr
 
        clear_bit(IPOIB_MCAST_RUN, &priv->flags);
        netif_carrier_on(dev);
+
+unlock:
+       spin_unlock_irq(&priv->lock);
 }
 
 static void ipoib_mcast_start_thread(struct net_device *dev)
 {
        struct ipoib_dev_priv *priv = netdev_priv(dev);
+       unsigned long flags;
 
        ipoib_dbg_mcast(priv, "starting multicast thread\n");
 
-       down(&mcast_mutex);
+       spin_lock_irqsave(&ipoib_mcast_lock, flags);
        if (!test_and_set_bit(IPOIB_MCAST_RUN, &priv->flags))
                queue_work(ipoib_workqueue, &priv->mcast_task);
-       up(&mcast_mutex);
+       spin_unlock_irqrestore(&ipoib_mcast_lock, flags);
 }
 
 static void ipoib_mcast_stop_thread(struct net_device *dev)
 {
        struct ipoib_dev_priv *priv = netdev_priv(dev);
        struct ipoib_mcast *mcast;
+       unsigned long flags;
 
        ipoib_dbg_mcast(priv, "stopping multicast thread\n");
 
-       down(&mcast_mutex);
+       spin_lock_irqsave(&priv->lock, flags);
+
+       spin_lock(&ipoib_mcast_lock);
        clear_bit(IPOIB_MCAST_RUN, &priv->flags);
        cancel_delayed_work(&priv->mcast_task);
-       up(&mcast_mutex);
+       spin_unlock(&ipoib_mcast_lock);
 
        if (priv->broadcast && priv->broadcast->query) {
                ib_sa_cancel_query(priv->broadcast->query_id, 
priv->broadcast->query);
-               priv->broadcast->query = NULL;
+               spin_unlock_irqrestore(&priv->lock, flags);
                ipoib_dbg_mcast(priv, "waiting for bcast\n");
                wait_for_completion(&priv->broadcast->done);
+               spin_lock_irqsave(&priv->lock, flags);
        }
 
        list_for_each_entry(mcast, &priv->multicast_list, list) {
                if (mcast->query) {
                        ib_sa_cancel_query(mcast->query_id, mcast->query);
-                       mcast->query = NULL;
+                       spin_unlock_irqrestore(&priv->lock, flags);
                        ipoib_dbg_mcast(priv, "waiting for MGID " IPOIB_GID_FMT 
"\n",
                                        IPOIB_GID_ARG(mcast->mcmember.mgid));
                        wait_for_completion(&mcast->done);
+                       spin_lock_irqsave(&priv->lock, flags);
                }
        }
+
+       spin_unlock_irqrestore(&priv->lock, flags);
 }
 
 static int ipoib_mcast_leave(struct net_device *dev, struct ipoib_mcast *mcast)
@@ -621,6 +656,7 @@ static int ipoib_mcast_leave(struct net_
        struct ib_sa_mcmember_rec rec = {
                .join_state = 1
        };
+       struct ib_sa_query *query;
        int ret = 0;
 
        if (!test_and_clear_bit(IPOIB_MCAST_FLAG_ATTACHED, &mcast->flags))
@@ -629,6 +665,8 @@ static int ipoib_mcast_leave(struct net_
        ipoib_dbg_mcast(priv, "leaving MGID " IPOIB_GID_FMT "\n",
                        IPOIB_GID_ARG(mcast->mcmember.mgid));
 
+       BUG_ON(mcast->query);
+
        rec.mgid     = mcast->mcmember.mgid;
        rec.port_gid = priv->local_gid;
        rec.pkey     = cpu_to_be16(priv->pkey);
@@ -649,7 +687,7 @@ static int ipoib_mcast_leave(struct net_
                                        IB_SA_MCMEMBER_REC_PKEY         |
                                        IB_SA_MCMEMBER_REC_JOIN_STATE,
                                        0, GFP_ATOMIC, NULL,
-                                       mcast, &mcast->query);
+                                       mcast, &query);
        if (ret < 0)
                ipoib_warn(priv, "ib_sa_mcmember_rec_delete failed "
                           "for leave (result = %d)\n", ret);
@@ -675,7 +713,7 @@ void ipoib_mcast_send(struct net_device 
                ipoib_dbg_mcast(priv, "setting up send only multicast group for 
"
                                IPOIB_GID_FMT "\n", IPOIB_GID_ARG(*mgid));
 
-               mcast = ipoib_mcast_alloc(dev, 0);
+               mcast = ipoib_mcast_alloc(dev);
                if (!mcast) {
                        ipoib_warn(priv, "unable to allocate memory for "
                                   "multicast structure\n");
@@ -741,7 +779,7 @@ static void ipoib_mcast_dev_flush(struct
 
        spin_lock_irqsave(&priv->lock, flags);
        list_for_each_entry_safe(mcast, tmcast, &priv->multicast_list, list) {
-               nmcast = ipoib_mcast_alloc(dev, 0);
+               nmcast = ipoib_mcast_alloc(dev);
                if (nmcast) {
                        nmcast->flags =
                                mcast->flags & (1 << IPOIB_MCAST_FLAG_SENDONLY);
@@ -764,17 +802,16 @@ static void ipoib_mcast_dev_flush(struct
        }
 
        if (priv->broadcast) {
-               nmcast = ipoib_mcast_alloc(dev, 0);
+               nmcast = ipoib_mcast_alloc(dev);
                if (nmcast) {
                        nmcast->mcmember.mgid = priv->broadcast->mcmember.mgid;
 
                        rb_replace_node(&priv->broadcast->rb_node,
                                        &nmcast->rb_node,
                                        &priv->multicast_tree);
-
-                       list_add_tail(&priv->broadcast->list, &remove_list);
                }
 
+               list_add_tail(&priv->broadcast->list, &remove_list);
                priv->broadcast = nmcast;
        }
 
@@ -789,19 +826,23 @@ static void ipoib_mcast_dev_flush(struct
 static void ipoib_mcast_dev_down(struct net_device *dev)
 {
        struct ipoib_dev_priv *priv = netdev_priv(dev);
+       struct ipoib_mcast *mcast;
        unsigned long flags;
 
+       spin_lock_irqsave(&priv->lock, flags);
+
        /* Delete broadcast since it will be recreated */
        if (priv->broadcast) {
                ipoib_dbg_mcast(priv, "deleting broadcast group\n");
 
-               spin_lock_irqsave(&priv->lock, flags);
                rb_erase(&priv->broadcast->rb_node, &priv->multicast_tree);
-               spin_unlock_irqrestore(&priv->lock, flags);
-               ipoib_mcast_leave(dev, priv->broadcast);
-               ipoib_mcast_free(priv->broadcast);
+               mcast = priv->broadcast;
                priv->broadcast = NULL;
-       }
+               spin_unlock_irqrestore(&priv->lock, flags);
+               ipoib_mcast_leave(dev, mcast);
+               ipoib_mcast_free(mcast);
+       } else
+               spin_unlock_irqrestore(&priv->lock, flags);
 }
 
 void ipoib_mcast_restart_task(void *dev_ptr)
@@ -847,7 +888,7 @@ void ipoib_mcast_restart_task(void *dev_
                        ipoib_dbg_mcast(priv, "adding multicast entry for mgid "
                                        IPOIB_GID_FMT "\n", 
IPOIB_GID_ARG(mgid));
 
-                       nmcast = ipoib_mcast_alloc(dev, 0);
+                       nmcast = ipoib_mcast_alloc(dev);
                        if (!nmcast) {
                                ipoib_warn(priv, "unable to allocate memory for 
multicast structure\n");
                                continue;
Index: linux-2.6.14-dbg/drivers/infiniband/ulp/ipoib/ipoib.h
===================================================================
--- linux-2.6.14-dbg.orig/drivers/infiniband/ulp/ipoib/ipoib.h  2005-11-20 
12:18:43.000000000 +0200
+++ linux-2.6.14-dbg/drivers/infiniband/ulp/ipoib/ipoib.h       2005-11-20 
14:56:53.000000000 +0200
@@ -226,6 +226,7 @@ static inline struct ipoib_neigh **to_ip
 }
 
 extern struct workqueue_struct *ipoib_workqueue;
+extern spinlock_t ipoib_mcast_lock;
 
 /* functions */
 
-- 
MST
_______________________________________________
openib-general mailing list
[email protected]
http://openib.org/mailman/listinfo/openib-general

To unsubscribe, please visit http://openib.org/mailman/listinfo/openib-general

Reply via email to