Replace the dead flag in the session context with a closing flag protected
by a new per-session spinlock. Check the flag in the session lookup
functions so that we don't try to access session data while the session
is being destroyed.

Signed-off-by: James Chapman <jchap...@katalix.com>
---
 net/l2tp/l2tp_core.c | 34 +++++++++++++++++++++++++++++++++-
 net/l2tp/l2tp_core.h |  3 ++-
 2 files changed, 35 insertions(+), 2 deletions(-)

diff --git a/net/l2tp/l2tp_core.c b/net/l2tp/l2tp_core.c
index 477b96cf8ab3..869dec89ff0f 100644
--- a/net/l2tp/l2tp_core.c
+++ b/net/l2tp/l2tp_core.c
@@ -198,7 +198,14 @@ struct l2tp_session *l2tp_session_get(const struct net 
*net,
                rcu_read_lock_bh();
                hlist_for_each_entry_rcu(session, session_list, global_hlist) {
                        if (session->session_id == session_id) {
+                               spin_lock_bh(&session->lock);
+                               if (session->closing) {
+                                       spin_unlock_bh(&session->lock);
+                                       rcu_read_unlock_bh();
+                                       return NULL;
+                               }
                                l2tp_session_inc_refcount(session);
+                               spin_unlock_bh(&session->lock);
                                rcu_read_unlock_bh();
 
                                return session;
@@ -213,7 +220,14 @@ struct l2tp_session *l2tp_session_get(const struct net 
*net,
        read_lock_bh(&tunnel->hlist_lock);
        hlist_for_each_entry(session, session_list, hlist) {
                if (session->session_id == session_id) {
+                       spin_lock_bh(&session->lock);
+                       if (session->closing) {
+                               spin_unlock_bh(&session->lock);
+                               read_unlock_bh(&tunnel->hlist_lock);
+                               return NULL;
+                       }
                        l2tp_session_inc_refcount(session);
+                       spin_unlock_bh(&session->lock);
                        read_unlock_bh(&tunnel->hlist_lock);
 
                        return session;
@@ -234,6 +248,12 @@ struct l2tp_session *l2tp_session_get_nth(struct 
l2tp_tunnel *tunnel, int nth)
        read_lock_bh(&tunnel->hlist_lock);
        for (hash = 0; hash < L2TP_HASH_SIZE; hash++) {
                hlist_for_each_entry(session, &tunnel->session_hlist[hash], 
hlist) {
+                       spin_lock_bh(&session->lock);
+                       if (session->closing) {
+                               spin_unlock_bh(&session->lock);
+                               continue;
+                       }
+                       spin_unlock_bh(&session->lock);
                        if (++count > nth) {
                                l2tp_session_inc_refcount(session);
                                read_unlock_bh(&tunnel->hlist_lock);
@@ -261,6 +281,12 @@ struct l2tp_session *l2tp_session_get_by_ifname(const 
struct net *net,
        rcu_read_lock_bh();
        for (hash = 0; hash < L2TP_HASH_SIZE_2; hash++) {
                hlist_for_each_entry_rcu(session, 
&pn->l2tp_session_hlist[hash], global_hlist) {
+                       spin_lock_bh(&session->lock);
+                       if (session->closing) {
+                               spin_unlock_bh(&session->lock);
+                               continue;
+                       }
+                       spin_unlock_bh(&session->lock);
                        if (!strcmp(session->ifname, ifname)) {
                                l2tp_session_inc_refcount(session);
                                rcu_read_unlock_bh();
@@ -1678,8 +1704,13 @@ void __l2tp_session_unhash(struct l2tp_session *session)
  */
 int l2tp_session_delete(struct l2tp_session *session)
 {
-       if (test_and_set_bit(0, &session->dead))
+       spin_lock_bh(&session->lock);
+       if (session->closing) {
+               spin_unlock_bh(&session->lock);
                return 0;
+       }
+       session->closing = true;
+       spin_unlock_bh(&session->lock);
 
        __l2tp_session_unhash(session);
        l2tp_session_queue_purge(session);
@@ -1747,6 +1778,7 @@ struct l2tp_session *l2tp_session_create(int priv_size, 
struct l2tp_tunnel *tunn
 
                INIT_HLIST_NODE(&session->hlist);
                INIT_HLIST_NODE(&session->global_hlist);
+               spin_lock_init(&session->lock);
 
                /* Inherit debug options from tunnel */
                session->debug = tunnel->debug;
diff --git a/net/l2tp/l2tp_core.h b/net/l2tp/l2tp_core.h
index 4e098c822cd1..98709086fe84 100644
--- a/net/l2tp/l2tp_core.h
+++ b/net/l2tp/l2tp_core.h
@@ -74,7 +74,8 @@ struct l2tp_session_cfg {
 struct l2tp_session {
        int                     magic;          /* should be
                                                 * L2TP_SESSION_MAGIC */
-       long                    dead;
+       bool                    closing;
+       spinlock_t              lock;           /* protect closing */
 
        struct l2tp_tunnel      *tunnel;        /* back pointer to tunnel
                                                 * context */
-- 
1.9.1

Reply via email to