Simplify the session's relationship with its tunnel so that the session holds a ref on the tunnel, not on the tunnel's socket. This guarantees that the tunnel is always extant while one or more sessions exist on it. If the session has a socket (ppp), have the session hold a ref on that socket until the session is destroyed.
Since pppol2tp_sock_to_session returns a session and the session now holds a sock ref, have it return with a ref on the session. Fixes: fd558d186df2c ("l2tp: Split pppol2tp patch into separate l2tp and ppp parts") Fixes: f3c66d4e144a0 ("l2tp: prevent creation of sessions on terminated tunnels") --- net/l2tp/l2tp_core.c | 7 ++----- net/l2tp/l2tp_ppp.c | 36 ++++++++++++++++++------------------ 2 files changed, 20 insertions(+), 23 deletions(-) diff --git a/net/l2tp/l2tp_core.c b/net/l2tp/l2tp_core.c index 691fe9368d91..477b96cf8ab3 100644 --- a/net/l2tp/l2tp_core.c +++ b/net/l2tp/l2tp_core.c @@ -290,6 +290,7 @@ int l2tp_session_register(struct l2tp_session *session, spin_unlock_bh(&tunnel->lock); return -ENODEV; } + l2tp_tunnel_inc_refcount(tunnel); spin_unlock_bh(&tunnel->lock); head = l2tp_session_id_hash(tunnel, session->session_id); @@ -315,14 +316,9 @@ int l2tp_session_register(struct l2tp_session *session, goto err_tlock_pnlock; } - l2tp_tunnel_inc_refcount(tunnel); - sock_hold(tunnel->sock); hlist_add_head_rcu(&session->global_hlist, g_head); spin_unlock_bh(&pn->l2tp_session_hlist_lock); - } else { - l2tp_tunnel_inc_refcount(tunnel); - sock_hold(tunnel->sock); } hlist_add_head(&session->hlist, head); @@ -334,6 +330,7 @@ int l2tp_session_register(struct l2tp_session *session, spin_unlock_bh(&pn->l2tp_session_hlist_lock); err_tlock: write_unlock_bh(&tunnel->hlist_lock); + l2tp_tunnel_dec_refcount(tunnel); return err; } diff --git a/net/l2tp/l2tp_ppp.c b/net/l2tp/l2tp_ppp.c index fe5a0043dd32..ff95a4d4eac5 100644 --- a/net/l2tp/l2tp_ppp.c +++ b/net/l2tp/l2tp_ppp.c @@ -166,16 +166,17 @@ static inline struct l2tp_session *pppol2tp_sock_to_session(struct sock *sk) if (sk == NULL) return NULL; - sock_hold(sk); - session = (struct l2tp_session *)(sk->sk_user_data); + rcu_read_lock_bh(); + session = rcu_dereference_bh(__sk_user_data((sk))); if (session == NULL) { - sock_put(sk); - goto out; + rcu_read_unlock_bh(); + return NULL; } + l2tp_session_inc_refcount(session); 
+ rcu_read_unlock(); BUG_ON(session->magic != L2TP_SESSION_MAGIC); -out: return session; } @@ -243,8 +244,8 @@ static void pppol2tp_recv(struct l2tp_session *session, struct sk_buff *skb, int /* If the socket is bound, send it in to PPP's input queue. Otherwise * queue it on the session socket. */ - rcu_read_lock(); - sk = rcu_dereference(ps->sk); + rcu_read_lock_bh(); + sk = rcu_dereference_bh(ps->sk); if (sk == NULL) goto no_sock; @@ -267,12 +268,12 @@ static void pppol2tp_recv(struct l2tp_session *session, struct sk_buff *skb, int kfree_skb(skb); } } - rcu_read_unlock(); + rcu_read_unlock_bh(); return; no_sock: - rcu_read_unlock(); + rcu_read_unlock_bh(); l2tp_info(session, L2TP_MSG_DATA, "%s: no socket\n", session->name); kfree_skb(skb); } @@ -341,12 +342,12 @@ static int pppol2tp_sendmsg(struct socket *sock, struct msghdr *m, l2tp_xmit_skb(session, skb, session->hdr_len); local_bh_enable(); - sock_put(sk); + l2tp_session_dec_refcount(session); return total_len; error_put_sess: - sock_put(sk); + l2tp_session_dec_refcount(session); error: return error; } @@ -400,12 +401,12 @@ static int pppol2tp_xmit(struct ppp_channel *chan, struct sk_buff *skb) l2tp_xmit_skb(session, skb, session->hdr_len); local_bh_enable(); - sock_put(sk); + l2tp_session_dec_refcount(session); return 1; abort_put_sess: - sock_put(sk); + l2tp_session_dec_refcount(session); abort: /* Free the original skb */ kfree_skb(skb); @@ -483,7 +484,6 @@ static int pppol2tp_release(struct socket *sock) sock->sk = NULL; session = pppol2tp_sock_to_session(sk); - if (session != NULL) { struct pppol2tp_session *ps; @@ -976,7 +976,7 @@ static int pppol2tp_getname(struct socket *sock, struct sockaddr *uaddr, *usockaddr_len = len; error = 0; - sock_put(sk); + l2tp_session_dec_refcount(session); end: return error; } @@ -1247,7 +1247,7 @@ static int pppol2tp_ioctl(struct socket *sock, unsigned int cmd, err = pppol2tp_session_ioctl(session, cmd, arg); end_put_sess: - sock_put(sk); + 
l2tp_session_dec_refcount(session); end: return err; } @@ -1398,7 +1398,7 @@ static int pppol2tp_setsockopt(struct socket *sock, int level, int optname, err = pppol2tp_session_setsockopt(sk, session, optname, val); } - sock_put(sk); + l2tp_session_dec_refcount(session); end: return err; } @@ -1530,7 +1530,7 @@ static int pppol2tp_getsockopt(struct socket *sock, int level, int optname, err = 0; end_put_sess: - sock_put(sk); + l2tp_session_dec_refcount(session); end: return err; } -- 1.9.1