On Mon, 12 Feb 2018, Juergen Gross wrote:
> On 05/02/18 23:51, Stefano Stabellini wrote:
> > Introduce a per sock_mapping refcount, in addition to the existing
> > global refcount. Thanks to the sock_mapping refcount, we can safely wait
> > for it to be 1 in pvcalls_front_release before freeing an active socket,
> > instead of waiting for the global refcount to be 1.
> >
> > Signed-off-by: Stefano Stabellini
> > ---
> > drivers/xen/pvcalls-front.c | 190
> > ++--
> > 1 file changed, 78 insertions(+), 112 deletions(-)
> >
> > diff --git a/drivers/xen/pvcalls-front.c b/drivers/xen/pvcalls-front.c
> > index 4c789e6..164d3ad 100644
> > --- a/drivers/xen/pvcalls-front.c
> > +++ b/drivers/xen/pvcalls-front.c
> > @@ -60,6 +60,7 @@ struct sock_mapping {
> > bool active_socket;
> > struct list_head list;
> > struct socket *sock;
> > + atomic_t refcount;
> > union {
> > struct {
> > int irq;
> > @@ -93,6 +94,33 @@ struct sock_mapping {
> > };
> > };
> >
> > +static inline struct sock_mapping *pvcalls_enter_sock(struct socket *sock)
> > +{
> > + struct sock_mapping *map = NULL;
> > +
> > + if (!pvcalls_front_dev || &pvcalls_front_dev->dev == NULL)
>
> Did you mean:
> if (!pvcalls_front_dev || !pvcalls_front_dev->dev)
I actually meant:
if (!pvcalls_front_dev || dev_get_drvdata(&pvcalls_front_dev->dev) == NULL)
> > + return ERR_PTR(-ENOTCONN);
> > +
> > + pvcalls_enter();
> > + map = (struct sock_mapping *) sock->sk->sk_send_head;
>
> Style: no blank after the cast, please (multiple times).
OK, I'll fix them all
> > + if (map == NULL) {
> > + pvcalls_exit()
> > + return ERR_PTR(-ENOTSOCK);
> > + }
> > +
> > + atomic_inc(&map->refcount);
> > + return map;
> > +}
> > +
> > +static inline void pvcalls_exit_sock(struct socket *sock)
> > +{
> > + struct sock_mapping *map = NULL;
> > +
> > + map = (struct sock_mapping *) sock->sk->sk_send_head;
> > + atomic_dec(&map->refcount);
> > + pvcalls_exit();
> > +}
> > +
> > static inline int get_request(struct pvcalls_bedata *bedata, int *req_id)
> > {
> > *req_id = bedata->ring.req_prod_pvt & (RING_SIZE(&bedata->ring) - 1);
> > @@ -369,31 +397,23 @@ int pvcalls_front_connect(struct socket *sock, struct sockaddr *addr,
> > if (addr->sa_family != AF_INET || sock->type != SOCK_STREAM)
> > return -EOPNOTSUPP;
> >
> > - pvcalls_enter();
> > - if (!pvcalls_front_dev) {
> > - pvcalls_exit();
> > - return -ENOTCONN;
> > - }
> > + map = pvcalls_enter_sock(sock);
> > + if (IS_ERR(map))
> > + return PTR_ERR(map);
> >
> > bedata = dev_get_drvdata(&pvcalls_front_dev->dev);
> >
> > - map = (struct sock_mapping *)sock->sk->sk_send_head;
> > - if (!map) {
> > - pvcalls_exit();
> > - return -ENOTSOCK;
> > - }
> > -
> > spin_lock(&bedata->socket_lock);
> > ret = get_request(bedata, &req_id);
> > if (ret < 0) {
> > spin_unlock(&bedata->socket_lock);
> > - pvcalls_exit();
> > + pvcalls_exit_sock(sock);
> > return ret;
> > }
> > ret = create_active(map, &evtchn);
> > if (ret < 0) {
> > spin_unlock(&bedata->socket_lock);
> > - pvcalls_exit();
> > + pvcalls_exit_sock(sock);
> > return ret;
> > }
> >
> > @@ -423,7 +443,7 @@ int pvcalls_front_connect(struct socket *sock, struct sockaddr *addr,
> > smp_rmb();
> > ret = bedata->rsp[req_id].ret;
> > bedata->rsp[req_id].req_id = PVCALLS_INVALID_ID;
> > - pvcalls_exit();
> > + pvcalls_exit_sock(sock);
> > return ret;
> > }
> >
> > @@ -488,23 +508,15 @@ int pvcalls_front_sendmsg(struct socket *sock, struct msghdr *msg,
> > if (flags & (MSG_CONFIRM|MSG_DONTROUTE|MSG_EOR|MSG_OOB))
> > return -EOPNOTSUPP;
> >
> > - pvcalls_enter();
> > - if (!pvcalls_front_dev) {
> > - pvcalls_exit();
> > - return -ENOTCONN;
> > - }
> > + map = pvcalls_enter_sock(sock);
> > + if (IS_ERR(map))
> > + return PTR_ERR(map);
> > bedata = dev_get_drvdata(&pvcalls_front_dev->dev);
> >
> > - map = (struct sock_mapping *) sock->sk->sk_send_head;
> > - if (!map) {
> > - pvcalls_exit();
> > - return -ENOTSOCK;
> > - }
> > -
> > mutex_lock(&map->active.out_mutex);
> > if ((flags & MSG_DONTWAIT) && !pvcalls_front_write_todo(map)) {
> > mutex_unlock(&map->active.out_mutex);
> > - pvcalls_exit();
> > + pvcalls_exit_sock(sock);
> > return -EAGAIN;
> > }
> > if (len > INT_MAX)
> > @@ -526,7 +538,7 @@ int pvcalls_front_sendmsg(struct socket *sock, struct msghdr *msg,
> > tot_sent = sent;
> >
> > mutex_unlock(&map->active.out_mutex);
> > - pvcalls_exit();
> > + pvcalls_exit_sock(sock);
> > return tot_sent;
> > }
> >
> > @@ -591,19 +603,11 @@ int pvcalls_front_recvmsg(struct