Re: [PATCH v3 1/2] pvcalls-front: introduce a per sock_mapping refcount

2018-02-15 Thread Juergen Gross
On 14/02/18 19:28, Stefano Stabellini wrote:
> Introduce a per sock_mapping refcount, in addition to the existing
> global refcount. Thanks to the sock_mapping refcount, we can safely wait
> for it to be 1 in pvcalls_front_release before freeing an active socket,
> instead of waiting for the global refcount to be 1.
> 
> Signed-off-by: Stefano Stabellini 

Acked-by: Juergen Gross 


Juergen


[PATCH v3 1/2] pvcalls-front: introduce a per sock_mapping refcount

2018-02-14 Thread Stefano Stabellini
Introduce a per sock_mapping refcount, in addition to the existing
global refcount. Thanks to the sock_mapping refcount, we can safely wait
for it to be 1 in pvcalls_front_release before freeing an active socket,
instead of waiting for the global refcount to be 1.

Signed-off-by: Stefano Stabellini 

---
Changes in v3:
- remove pointless initializers
- reorder pvcalls_enter_sock
---
 drivers/xen/pvcalls-front.c | 191 ++--
 1 file changed, 79 insertions(+), 112 deletions(-)

diff --git a/drivers/xen/pvcalls-front.c b/drivers/xen/pvcalls-front.c
index 4c789e6..18d1bac 100644
--- a/drivers/xen/pvcalls-front.c
+++ b/drivers/xen/pvcalls-front.c
@@ -60,6 +60,7 @@ struct sock_mapping {
bool active_socket;
struct list_head list;
struct socket *sock;
+   atomic_t refcount;
union {
struct {
int irq;
@@ -93,6 +94,32 @@ struct sock_mapping {
};
 };
 
+static inline struct sock_mapping *pvcalls_enter_sock(struct socket *sock)
+{
+   struct sock_mapping *map;
+
+   if (!pvcalls_front_dev ||
+   dev_get_drvdata(&pvcalls_front_dev->dev) == NULL)
+   return ERR_PTR(-ENOTCONN);
+
+   map = (struct sock_mapping *)sock->sk->sk_send_head;
+   if (map == NULL)
+   return ERR_PTR(-ENOTSOCK);
+
+   pvcalls_enter();
+   atomic_inc(&map->refcount);
+   return map;
+}
+
+static inline void pvcalls_exit_sock(struct socket *sock)
+{
+   struct sock_mapping *map;
+
+   map = (struct sock_mapping *)sock->sk->sk_send_head;
+   atomic_dec(&map->refcount);
+   pvcalls_exit();
+}
+
 static inline int get_request(struct pvcalls_bedata *bedata, int *req_id)
 {
*req_id = bedata->ring.req_prod_pvt & (RING_SIZE(&bedata->ring) - 1);
@@ -369,31 +396,23 @@ int pvcalls_front_connect(struct socket *sock, struct sockaddr *addr,
if (addr->sa_family != AF_INET || sock->type != SOCK_STREAM)
return -EOPNOTSUPP;
 
-   pvcalls_enter();
-   if (!pvcalls_front_dev) {
-   pvcalls_exit();
-   return -ENOTCONN;
-   }
+   map = pvcalls_enter_sock(sock);
+   if (IS_ERR(map))
+   return PTR_ERR(map);
 
bedata = dev_get_drvdata(&pvcalls_front_dev->dev);
 
-   map = (struct sock_mapping *)sock->sk->sk_send_head;
-   if (!map) {
-   pvcalls_exit();
-   return -ENOTSOCK;
-   }
-
spin_lock(&bedata->socket_lock);
ret = get_request(bedata, &req_id);
if (ret < 0) {
spin_unlock(&bedata->socket_lock);
-   pvcalls_exit();
+   pvcalls_exit_sock(sock);
return ret;
}
ret = create_active(map, &evtchn);
if (ret < 0) {
spin_unlock(&bedata->socket_lock);
-   pvcalls_exit();
+   pvcalls_exit_sock(sock);
return ret;
}
 
@@ -423,7 +442,7 @@ int pvcalls_front_connect(struct socket *sock, struct sockaddr *addr,
smp_rmb();
ret = bedata->rsp[req_id].ret;
bedata->rsp[req_id].req_id = PVCALLS_INVALID_ID;
-   pvcalls_exit();
+   pvcalls_exit_sock(sock);
return ret;
 }
 
@@ -488,23 +507,15 @@ int pvcalls_front_sendmsg(struct socket *sock, struct msghdr *msg,
if (flags & (MSG_CONFIRM|MSG_DONTROUTE|MSG_EOR|MSG_OOB))
return -EOPNOTSUPP;
 
-   pvcalls_enter();
-   if (!pvcalls_front_dev) {
-   pvcalls_exit();
-   return -ENOTCONN;
-   }
+   map = pvcalls_enter_sock(sock);
+   if (IS_ERR(map))
+   return PTR_ERR(map);
bedata = dev_get_drvdata(&pvcalls_front_dev->dev);
 
-   map = (struct sock_mapping *) sock->sk->sk_send_head;
-   if (!map) {
-   pvcalls_exit();
-   return -ENOTSOCK;
-   }
-
mutex_lock(&map->active.out_mutex);
if ((flags & MSG_DONTWAIT) && !pvcalls_front_write_todo(map)) {
mutex_unlock(&map->active.out_mutex);
-   pvcalls_exit();
+   pvcalls_exit_sock(sock);
return -EAGAIN;
}
if (len > INT_MAX)
@@ -526,7 +537,7 @@ int pvcalls_front_sendmsg(struct socket *sock, struct msghdr *msg,
tot_sent = sent;
 
mutex_unlock(&map->active.out_mutex);
-   pvcalls_exit();
+   pvcalls_exit_sock(sock);
return tot_sent;
 }
 
@@ -591,19 +602,11 @@ int pvcalls_front_recvmsg(struct socket *sock, struct msghdr *msg, size_t len,
if (flags & (MSG_CMSG_CLOEXEC|MSG_ERRQUEUE|MSG_OOB|MSG_TRUNC))
return -EOPNOTSUPP;
 
-   pvcalls_enter();
-   if (!pvcalls_front_dev) {
-   pvcalls_exit();
-   return -ENOTCONN;
-   }
+   map = pvcalls_enter_sock(sock);
+   if (IS_ERR(map))
+   return PTR_ERR(map);
bedata = dev_get_drvdata(&pvcalls_front_dev->