From: Andrey Ignatov <r...@fb.com>

Refactor the `bind()` code so that it can be called from the BPF helper
function `bpf_bind()` (to be added in a follow-up patch). The implementations
of `inet_bind()` and `inet6_bind()` are split into `__inet_bind()` and
`__inet6_bind()` respectively. These functions can be used both from the
regular bind path (`inet_bind()`/`inet6_bind()`) and from the `bpf_bind()`
context.

The new functions take two additional arguments.

`force_bind_address_no_port` forces binding to the IP address only,
regardless of the `inet_sock.bind_address_no_port` field: when the
requested port is 0, no local port is reserved. This allows `bpf_bind()`
to bind the local end of a connection to a desired IP without changing the
socket's `bind_address_no_port` field. That is useful because `bpf_bind()`
can return an error, in which case we would otherwise need to restore the
original value of `bind_address_no_port` if we had changed it before
calling the helper.

`with_lock` specifies whether the socket lock should be taken while
working with `struct sock`. The regular bind path (`inet_bind()` and
`inet6_bind()`) passes `true`, so the old behavior is preserved. The
`bpf_bind()` use-case will pass `false`, because all call sites from which
`bpf_bind()` will be called already hold the socket lock.

Signed-off-by: Andrey Ignatov <r...@fb.com>
Acked-by: Alexei Starovoitov <a...@kernel.org>
Signed-off-by: Alexei Starovoitov <a...@kernel.org>
---
 include/net/inet_common.h |  2 ++
 include/net/ipv6.h        |  2 ++
 net/ipv4/af_inet.c        | 39 ++++++++++++++++++++++++---------------
 net/ipv6/af_inet6.c       | 37 ++++++++++++++++++++++++-------------
 4 files changed, 52 insertions(+), 28 deletions(-)

diff --git a/include/net/inet_common.h b/include/net/inet_common.h
index 500f81375200..384b90c62c0b 100644
--- a/include/net/inet_common.h
+++ b/include/net/inet_common.h
@@ -32,6 +32,8 @@ int inet_shutdown(struct socket *sock, int how);
 int inet_listen(struct socket *sock, int backlog);
 void inet_sock_destruct(struct sock *sk);
 int inet_bind(struct socket *sock, struct sockaddr *uaddr, int addr_len);
+int __inet_bind(struct sock *sk, struct sockaddr *uaddr, int addr_len,
+               bool force_bind_address_no_port, bool with_lock);
 int inet_getname(struct socket *sock, struct sockaddr *uaddr,
                 int peer);
 int inet_ioctl(struct socket *sock, unsigned int cmd, unsigned long arg);
diff --git a/include/net/ipv6.h b/include/net/ipv6.h
index 50a6f0ddb878..2e5fedc56e59 100644
--- a/include/net/ipv6.h
+++ b/include/net/ipv6.h
@@ -1066,6 +1066,8 @@ void ipv6_local_error(struct sock *sk, int err, struct flowi6 *fl6, u32 info);
 void ipv6_local_rxpmtu(struct sock *sk, struct flowi6 *fl6, u32 mtu);
 
 int inet6_release(struct socket *sock);
+int __inet6_bind(struct sock *sock, struct sockaddr *uaddr, int addr_len,
+                bool force_bind_address_no_port, bool with_lock);
 int inet6_bind(struct socket *sock, struct sockaddr *uaddr, int addr_len);
 int inet6_getname(struct socket *sock, struct sockaddr *uaddr,
                  int peer);
diff --git a/net/ipv4/af_inet.c b/net/ipv4/af_inet.c
index 2dec266507dc..e203a39d6988 100644
--- a/net/ipv4/af_inet.c
+++ b/net/ipv4/af_inet.c
@@ -432,30 +432,37 @@ EXPORT_SYMBOL(inet_release);
 
 int inet_bind(struct socket *sock, struct sockaddr *uaddr, int addr_len)
 {
-       struct sockaddr_in *addr = (struct sockaddr_in *)uaddr;
        struct sock *sk = sock->sk;
-       struct inet_sock *inet = inet_sk(sk);
-       struct net *net = sock_net(sk);
-       unsigned short snum;
-       int chk_addr_ret;
-       u32 tb_id = RT_TABLE_LOCAL;
        int err;
 
        /* If the socket has its own bind function then use it. (RAW) */
        if (sk->sk_prot->bind) {
-               err = sk->sk_prot->bind(sk, uaddr, addr_len);
-               goto out;
+               return sk->sk_prot->bind(sk, uaddr, addr_len);
        }
-       err = -EINVAL;
        if (addr_len < sizeof(struct sockaddr_in))
-               goto out;
+               return -EINVAL;
 
        /* BPF prog is run before any checks are done so that if the prog
         * changes context in a wrong way it will be caught.
         */
        err = BPF_CGROUP_RUN_PROG_INET4_BIND(sk, uaddr);
        if (err)
-               goto out;
+               return err;
+
+       return __inet_bind(sk, uaddr, addr_len, false, true);
+}
+EXPORT_SYMBOL(inet_bind);
+
+int __inet_bind(struct sock *sk, struct sockaddr *uaddr, int addr_len,
+               bool force_bind_address_no_port, bool with_lock)
+{
+       struct sockaddr_in *addr = (struct sockaddr_in *)uaddr;
+       struct inet_sock *inet = inet_sk(sk);
+       struct net *net = sock_net(sk);
+       unsigned short snum;
+       int chk_addr_ret;
+       u32 tb_id = RT_TABLE_LOCAL;
+       int err;
 
        if (addr->sin_family != AF_INET) {
                /* Compatibility games : accept AF_UNSPEC (mapped to AF_INET)
@@ -499,7 +506,8 @@ int inet_bind(struct socket *sock, struct sockaddr *uaddr, int addr_len)
         *      would be illegal to use them (multicast/broadcast) in
         *      which case the sending device address is used.
         */
-       lock_sock(sk);
+       if (with_lock)
+               lock_sock(sk);
 
        /* Check these errors (active socket, double bind). */
        err = -EINVAL;
@@ -511,7 +519,8 @@ int inet_bind(struct socket *sock, struct sockaddr *uaddr, int addr_len)
                inet->inet_saddr = 0;  /* Use device */
 
        /* Make sure we are allowed to bind here. */
-       if ((snum || !inet->bind_address_no_port) &&
+       if ((snum || !(inet->bind_address_no_port ||
+                      force_bind_address_no_port)) &&
            sk->sk_prot->get_port(sk, snum)) {
                inet->inet_saddr = inet->inet_rcv_saddr = 0;
                err = -EADDRINUSE;
@@ -528,11 +537,11 @@ int inet_bind(struct socket *sock, struct sockaddr *uaddr, int addr_len)
        sk_dst_reset(sk);
        err = 0;
 out_release_sock:
-       release_sock(sk);
+       if (with_lock)
+               release_sock(sk);
 out:
        return err;
 }
-EXPORT_SYMBOL(inet_bind);
 
 int inet_dgram_connect(struct socket *sock, struct sockaddr *uaddr,
                       int addr_len, int flags)
diff --git a/net/ipv6/af_inet6.c b/net/ipv6/af_inet6.c
index fa24e3f06ac6..13110bee5c14 100644
--- a/net/ipv6/af_inet6.c
+++ b/net/ipv6/af_inet6.c
@@ -277,15 +277,7 @@ static int inet6_create(struct net *net, struct socket *sock, int protocol,
 /* bind for INET6 API */
 int inet6_bind(struct socket *sock, struct sockaddr *uaddr, int addr_len)
 {
-       struct sockaddr_in6 *addr = (struct sockaddr_in6 *)uaddr;
        struct sock *sk = sock->sk;
-       struct inet_sock *inet = inet_sk(sk);
-       struct ipv6_pinfo *np = inet6_sk(sk);
-       struct net *net = sock_net(sk);
-       __be32 v4addr = 0;
-       unsigned short snum;
-       bool saved_ipv6only;
-       int addr_type = 0;
        int err = 0;
 
        /* If the socket has its own bind function then use it. */
@@ -302,11 +294,28 @@ int inet6_bind(struct socket *sock, struct sockaddr *uaddr, int addr_len)
        if (err)
                return err;
 
+       return __inet6_bind(sk, uaddr, addr_len, false, true);
+}
+EXPORT_SYMBOL(inet6_bind);
+
+int __inet6_bind(struct sock *sk, struct sockaddr *uaddr, int addr_len,
+                bool force_bind_address_no_port, bool with_lock)
+{
+       struct sockaddr_in6 *addr = (struct sockaddr_in6 *)uaddr;
+       struct inet_sock *inet = inet_sk(sk);
+       struct ipv6_pinfo *np = inet6_sk(sk);
+       struct net *net = sock_net(sk);
+       __be32 v4addr = 0;
+       unsigned short snum;
+       bool saved_ipv6only;
+       int addr_type = 0;
+       int err = 0;
+
        if (addr->sin6_family != AF_INET6)
                return -EAFNOSUPPORT;
 
        addr_type = ipv6_addr_type(&addr->sin6_addr);
-       if ((addr_type & IPV6_ADDR_MULTICAST) && sock->type == SOCK_STREAM)
+       if ((addr_type & IPV6_ADDR_MULTICAST) && sk->sk_type == SOCK_STREAM)
                return -EINVAL;
 
        snum = ntohs(addr->sin6_port);
@@ -314,7 +323,8 @@ int inet6_bind(struct socket *sock, struct sockaddr *uaddr, int addr_len)
            !ns_capable(net->user_ns, CAP_NET_BIND_SERVICE))
                return -EACCES;
 
-       lock_sock(sk);
+       if (with_lock)
+               lock_sock(sk);
 
        /* Check these errors (active socket, double bind). */
        if (sk->sk_state != TCP_CLOSE || inet->inet_num) {
@@ -402,7 +412,8 @@ int inet6_bind(struct socket *sock, struct sockaddr *uaddr, int addr_len)
                sk->sk_ipv6only = 1;
 
        /* Make sure we are allowed to bind here. */
-       if ((snum || !inet->bind_address_no_port) &&
+       if ((snum || !(inet->bind_address_no_port ||
+                      force_bind_address_no_port)) &&
            sk->sk_prot->get_port(sk, snum)) {
                sk->sk_ipv6only = saved_ipv6only;
                inet_reset_saddr(sk);
@@ -418,13 +429,13 @@ int inet6_bind(struct socket *sock, struct sockaddr *uaddr, int addr_len)
        inet->inet_dport = 0;
        inet->inet_daddr = 0;
 out:
-       release_sock(sk);
+       if (with_lock)
+               release_sock(sk);
        return err;
 out_unlock:
        rcu_read_unlock();
        goto out;
 }
-EXPORT_SYMBOL(inet6_bind);
 
 int inet6_release(struct socket *sock)
 {
-- 
2.9.5

Reply via email to