This patch adds a BPF_PROG_TYPE_SK_REUSEPORT which can select
a SO_REUSEPORT sk from a BPF_MAP_TYPE_REUSEPORT_SOCKARRAY.  Like other
non-SK_FILTER/CGROUP_SKB programs, it requires CAP_SYS_ADMIN.

BPF_PROG_TYPE_SK_REUSEPORT introduces "struct sk_reuseport_kern"
to store the bpf context instead of using the skb->cb[48].

At SO_REUSEPORT sk lookup time, we are in the middle of transitioning
from a lower layer (ipv4/ipv6) to an upper layer (udp/tcp).  At this
point, it is not always clear where the bpf context can be appended
in the skb->cb[48] to avoid saving-and-restoring cb[], even putting
aside the differences between ipv4-vs-ipv6 and udp-vs-tcp.  It is also
not clear whether the lower layer will only ever be ipv4 and ipv6, or
whether it will touch the cb[] again before transitioning to the upper
layer.

For example, udp_gro_receive() uses the 48-byte NAPI_GRO_CB
instead of IP[6]CB, and it may still modify the cb[] after calling
udp[46]_lib_lookup_skb().  For these reasons, if skb->cb were used
for the bpf ctx, saving-and-restoring would be needed, and likely
the whole 48-byte cb[] would have to be saved and restored.

Instead of saving, setting and restoring the cb[], this patch opts
to create a new "struct sk_reuseport_kern" and setting the needed
values in there.

The new BPF_PROG_TYPE_SK_REUSEPORT and "struct sk_reuseport_(kern|md)"
will serve all ipv4/ipv6 + udp/tcp combinations.  There is no protocol
specific usage at this point, which is also in line with the current
sock_reuseport.c implementation (i.e. no protocol specific requirement).

In "struct sk_reuseport_md", this patch exposes data/data_end/len
with semantics similar to other existing usages.  Together
with "bpf_skb_load_bytes()" and "bpf_skb_load_bytes_relative()",
the bpf prog can peek anywhere in the skb.  The "bind_inany" field tells
the bpf prog that the reuseport group is bound to a local
INANY address, which cannot be learned from the skb.

The new "bind_inany" is added to "struct sock_reuseport".  It will be
used when running the new "BPF_PROG_TYPE_SK_REUSEPORT" bpf prog in order
to avoid repeating the "bind INANY" test on
"sk_v6_rcv_saddr/sk->sk_rcv_saddr" every time a bpf prog is run.  It can
only be properly initialized when a "sk->sk_reuseport"-enabled sk is
being added to a hashtable (i.e. during "reuseport_alloc()" and
"reuseport_add_sock()").

The new "sk_select_reuseport()" is the main helper that the
bpf prog will use to select a SO_REUSEPORT sk.  It is the only function
that can use the new BPF_MAP_TYPE_REUSEPORT_SOCKARRAY.  As mentioned in
the earlier patch, the validity of a selected sk is checked at
run time in "sk_select_reuseport()".  Doing the check at
verification time is difficult and inflexible (consider the map-in-map
use case).  The runtime check compares the selected sk's reuseport_id
with the reuseport_id that we want.  This helper will return -EXXX if the
selected sk cannot serve the incoming request (e.g. the reuseport_id
does not match).  The bpf prog can then decide whether to return SK_DROP
at its discretion.

When the bpf prog returns SK_PASS, the kernel will check if a
valid sk has been selected (i.e. "reuse_kern->selected_sk != NULL").
If so, it will use the selected sk.  If not, the kernel
will select one from "reuse->socks[]" (as before this patch).

The SK_DROP and SK_PASS handling logic will be in the next patch.

Signed-off-by: Martin KaFai Lau <ka...@fb.com>
Acked-by: Alexei Starovoitov <a...@kernel.org>
---
 include/linux/bpf_types.h       |   3 +
 include/linux/filter.h          |  15 ++
 include/net/addrconf.h          |   1 +
 include/net/sock_reuseport.h    |   6 +-
 include/uapi/linux/bpf.h        |  36 ++++-
 kernel/bpf/verifier.c           |   9 ++
 net/core/filter.c               | 269 +++++++++++++++++++++++++++++++-
 net/core/sock_reuseport.c       |  20 ++-
 net/ipv4/inet_connection_sock.c |   9 ++
 net/ipv4/inet_hashtables.c      |   5 +-
 net/ipv4/udp.c                  |   5 +-
 11 files changed, 365 insertions(+), 13 deletions(-)

diff --git a/include/linux/bpf_types.h b/include/linux/bpf_types.h
index 14fd6c02d258..cd26c090e7c0 100644
--- a/include/linux/bpf_types.h
+++ b/include/linux/bpf_types.h
@@ -29,6 +29,9 @@ BPF_PROG_TYPE(BPF_PROG_TYPE_CGROUP_DEVICE, cg_dev)
 #ifdef CONFIG_BPF_LIRC_MODE2
 BPF_PROG_TYPE(BPF_PROG_TYPE_LIRC_MODE2, lirc_mode2)
 #endif
+#ifdef CONFIG_INET
+/* sk_reuseport_{verifier,prog}_ops are only built with CONFIG_INET. */
+BPF_PROG_TYPE(BPF_PROG_TYPE_SK_REUSEPORT, sk_reuseport)
+#endif
 
 BPF_MAP_TYPE(BPF_MAP_TYPE_ARRAY, array_map_ops)
 BPF_MAP_TYPE(BPF_MAP_TYPE_PERCPU_ARRAY, percpu_array_map_ops)
diff --git a/include/linux/filter.h b/include/linux/filter.h
index c73dd7396886..29577c6f3289 100644
--- a/include/linux/filter.h
+++ b/include/linux/filter.h
@@ -32,6 +32,7 @@ struct seccomp_data;
 struct bpf_prog_aux;
 struct xdp_rxq_info;
 struct xdp_buff;
+struct sock_reuseport;
 
 /* ArgX, context and stack frame pointer register positions. Note,
  * Arg1, Arg2, Arg3, etc are used as argument mappings of function
@@ -798,6 +799,20 @@ void bpf_warn_invalid_xdp_action(u32 act);
 struct sock *do_sk_redirect_map(struct sk_buff *skb);
 struct sock *do_msg_redirect_map(struct sk_msg_buff *md);
 
+#ifdef CONFIG_INET
+/* Run a BPF_PROG_TYPE_SK_REUSEPORT @prog for @sk's reuseport group @reuse.
+ * Returns the sk selected by the prog (NULL if it returned SK_PASS without
+ * selecting one), or ERR_PTR(-ECONNREFUSED) for any other verdict.
+ */
+struct sock *bpf_run_sk_reuseport(struct sock_reuseport *reuse, struct sock *sk,
+                                 struct bpf_prog *prog, struct sk_buff *skb,
+                                 u32 hash);
+#else
+/* Without CONFIG_INET no prog can run, so nothing is ever selected. */
+static inline struct sock *
+bpf_run_sk_reuseport(struct sock_reuseport *reuse, struct sock *sk,
+                    struct bpf_prog *prog, struct sk_buff *skb,
+                    u32 hash)
+{
+       return NULL;
+}
+#endif
+
 #ifdef CONFIG_BPF_JIT
 extern int bpf_jit_enable;
 extern int bpf_jit_harden;
diff --git a/include/net/addrconf.h b/include/net/addrconf.h
index 5f43f7a70fe6..6def0351bcc3 100644
--- a/include/net/addrconf.h
+++ b/include/net/addrconf.h
@@ -108,6 +108,7 @@ int ipv6_get_lladdr(struct net_device *dev, struct in6_addr 
*addr,
                    u32 banned_flags);
 bool inet_rcv_saddr_equal(const struct sock *sk, const struct sock *sk2,
                          bool match_wildcard);
+/* True if @sk's local receive address is the wildcard (INANY) address. */
+bool inet_rcv_saddr_any(const struct sock *sk);
 void addrconf_join_solict(struct net_device *dev, const struct in6_addr *addr);
 void addrconf_leave_solict(struct inet6_dev *idev, const struct in6_addr 
*addr);
 
diff --git a/include/net/sock_reuseport.h b/include/net/sock_reuseport.h
index e1a7681856f7..73b569556be6 100644
--- a/include/net/sock_reuseport.h
+++ b/include/net/sock_reuseport.h
@@ -21,12 +21,14 @@ struct sock_reuseport {
        unsigned int            synq_overflow_ts;
        /* ID stays the same even after the size of socks[] grows. */
        unsigned int            reuseport_id;
+       bool                    bind_inany;
        struct bpf_prog __rcu   *prog;          /* optional BPF sock selector */
        struct sock             *socks[0];      /* array of sock pointers */
 };
 
-extern int reuseport_alloc(struct sock *sk);
-extern int reuseport_add_sock(struct sock *sk, struct sock *sk2);
+/* @bind_inany: whether @sk is bound to an INANY local address.  It is
+ * recorded in sock_reuseport->bind_inany for SK_REUSEPORT bpf progs;
+ * reuseport_alloc() on an existing group only ever sets it, never clears it.
+ */
+extern int reuseport_alloc(struct sock *sk, bool bind_inany);
+extern int reuseport_add_sock(struct sock *sk, struct sock *sk2,
+                             bool bind_inany);
 extern void reuseport_detach_sock(struct sock *sk);
 extern struct sock *reuseport_select_sock(struct sock *sk,
                                          u32 hash,
diff --git a/include/uapi/linux/bpf.h b/include/uapi/linux/bpf.h
index 40f584bc7da0..3102a2a23c31 100644
--- a/include/uapi/linux/bpf.h
+++ b/include/uapi/linux/bpf.h
@@ -151,6 +151,7 @@ enum bpf_prog_type {
        BPF_PROG_TYPE_CGROUP_SOCK_ADDR,
        BPF_PROG_TYPE_LWT_SEG6LOCAL,
        BPF_PROG_TYPE_LIRC_MODE2,
+       BPF_PROG_TYPE_SK_REUSEPORT,
 };
 
 enum bpf_attach_type {
@@ -2114,6 +2115,14 @@ union bpf_attr {
  *             the shared data.
  *     Return
  *             Pointer to the local storage area.
+ *
+ * int bpf_sk_select_reuseport(struct sk_reuseport_md *reuse, struct bpf_map *map, void *key, u64 flags)
+ *     Description
+ *             Select a SO_REUSEPORT sk from the
+ *             BPF_MAP_TYPE_REUSEPORT_SOCKARRAY *map* at *key*.
+ *             It checks that the selected sk matches the incoming
+ *             request in the skb.  *flags* is reserved for future
+ *             use and should be 0.
+ *     Return
+ *             0 on success, or a negative error in case of failure.
  */
 #define __BPF_FUNC_MAPPER(FN)          \
        FN(unspec),                     \
@@ -2197,7 +2206,8 @@ union bpf_attr {
        FN(rc_keydown),                 \
        FN(skb_cgroup_id),              \
        FN(get_current_cgroup_id),      \
-       FN(get_local_storage),
+       FN(get_local_storage),          \
+       FN(sk_select_reuseport),
 
 /* integer value in 'imm' field of BPF_CALL instruction selects which helper
  * function eBPF program intends to call
@@ -2414,6 +2424,30 @@ struct sk_msg_md {
        __u32 local_port;       /* stored in host byte order */
 };
 
+struct sk_reuseport_md {
+       /*
+        * Start of directly accessible data. It begins from
+        * the tcp/udp header.
+        *
+        * data and data_end are each wrapped in an 8-byte aligned union
+        * padded to 64 bits so that the struct layout (and thus every
+        * field offset) is the same on 32-bit and 64-bit kernels: the
+        * verifier only accepts 8-byte loads of these two fields.
+        */
+       union {
+               void *data;
+               __u64 :64;
+       } __attribute__((aligned(8)));
+       union {
+               void *data_end; /* End of directly accessible data */
+               __u64 :64;
+       } __attribute__((aligned(8)));
+       /*
+        * Total length of packet (starting from the tcp/udp header).
+        * Note that the directly accessible bytes (data_end - data)
+        * could be less than this "len".  Those bytes could be
+        * indirectly read by a helper "bpf_skb_load_bytes()".
+        */
+       __u32 len;
+       /*
+        * Eth protocol in the mac header (network byte order). e.g.
+        * ETH_P_IP(0x0800) and ETH_P_IPV6(0x86DD)
+        */
+       __u32 eth_protocol;
+       __u32 ip_protocol;      /* IP protocol. e.g. IPPROTO_TCP, IPPROTO_UDP */
+       __u32 bind_inany;       /* Is sock bound to an INANY address? */
+       __u32 hash;             /* A hash of the packet 4 tuples */
+};
+
 #define BPF_TAG_SIZE   8
 
 struct bpf_prog_info {
diff --git a/kernel/bpf/verifier.c b/kernel/bpf/verifier.c
index 587468a9c37d..ca90679a7fe5 100644
--- a/kernel/bpf/verifier.c
+++ b/kernel/bpf/verifier.c
@@ -1310,6 +1310,7 @@ static bool may_access_direct_pkt_data(struct 
bpf_verifier_env *env,
        case BPF_PROG_TYPE_LWT_IN:
        case BPF_PROG_TYPE_LWT_OUT:
        case BPF_PROG_TYPE_LWT_SEG6LOCAL:
+       case BPF_PROG_TYPE_SK_REUSEPORT:
                /* dst_input() and dst_output() can't write for now */
                if (t == BPF_WRITE)
                        return false;
@@ -2166,6 +2167,10 @@ static int check_map_func_compatibility(struct 
bpf_verifier_env *env,
                    func_id != BPF_FUNC_msg_redirect_hash)
                        goto error;
                break;
+       case BPF_MAP_TYPE_REUSEPORT_SOCKARRAY:
+               if (func_id != BPF_FUNC_sk_select_reuseport)
+                       goto error;
+               break;
        default:
                break;
        }
@@ -2217,6 +2222,10 @@ static int check_map_func_compatibility(struct 
bpf_verifier_env *env,
                if (map->map_type != BPF_MAP_TYPE_CGROUP_STORAGE)
                        goto error;
                break;
+       case BPF_FUNC_sk_select_reuseport:
+               if (map->map_type != BPF_MAP_TYPE_REUSEPORT_SOCKARRAY)
+                       goto error;
+               break;
        default:
                break;
        }
diff --git a/net/core/filter.c b/net/core/filter.c
index 56664c2f9cbb..f4c928709756 100644
--- a/net/core/filter.c
+++ b/net/core/filter.c
@@ -1462,7 +1462,7 @@ static int __reuseport_attach_prog(struct bpf_prog *prog, 
struct sock *sk)
                return -ENOMEM;
 
        if (sk_unhashed(sk) && sk->sk_reuseport) {
-               err = reuseport_alloc(sk);
+               err = reuseport_alloc(sk, false);
                if (err)
                        return err;
        } else if (!rcu_access_pointer(sk->sk_reuseport_cb)) {
@@ -7020,3 +7020,270 @@ int sk_get_filter(struct sock *sk, struct sock_filter 
__user *ubuf,
        release_sock(sk);
        return ret;
 }
+
+#ifdef CONFIG_INET
+/* Kernel-side ctx of a BPF_PROG_TYPE_SK_REUSEPORT prog.  The prog accesses
+ * it as struct sk_reuseport_md; sk_reuseport_convert_ctx_access() rewrites
+ * those accesses into loads from this struct (or through its skb/sk).
+ */
+struct sk_reuseport_kern {
+       struct sk_buff *skb;            /* incoming skb */
+       struct sock *sk;                /* sk passed to bpf_run_sk_reuseport() */
+       struct sock *selected_sk;       /* set by bpf_sk_select_reuseport() */
+       void *data_end;                 /* skb->data + skb_headlen(skb) */
+       u32 hash;
+       u32 reuseport_id;               /* copied from struct sock_reuseport */
+       bool bind_inany;                /* copied from struct sock_reuseport */
+};
+
+/* Fill the whole prog ctx in one assignment.  selected_sk starts out NULL
+ * and is only set if the prog calls bpf_sk_select_reuseport() successfully.
+ */
+static void bpf_init_reuseport_kern(struct sk_reuseport_kern *reuse_kern,
+                                   struct sock_reuseport *reuse,
+                                   struct sock *sk, struct sk_buff *skb,
+                                   u32 hash)
+{
+       *reuse_kern = (struct sk_reuseport_kern) {
+               .skb            = skb,
+               .sk             = sk,
+               .selected_sk    = NULL,
+               .data_end       = skb->data + skb_headlen(skb),
+               .hash           = hash,
+               .reuseport_id   = reuse->reuseport_id,
+               .bind_inany     = reuse->bind_inany,
+       };
+}
+
+/* Run @prog against the reuseport group @reuse of @sk for @skb.
+ *
+ * Returns the sk the prog selected with bpf_sk_select_reuseport() when it
+ * returns SK_PASS (NULL if it passed without selecting one), or
+ * ERR_PTR(-ECONNREFUSED) for any other verdict.
+ */
+struct sock *bpf_run_sk_reuseport(struct sock_reuseport *reuse, struct sock *sk,
+                                 struct bpf_prog *prog, struct sk_buff *skb,
+                                 u32 hash)
+{
+       struct sk_reuseport_kern reuse_kern;
+       enum sk_action action;
+
+       bpf_init_reuseport_kern(&reuse_kern, reuse, sk, skb, hash);
+       action = BPF_PROG_RUN(prog, &reuse_kern);
+
+       if (action == SK_PASS)
+               return reuse_kern.selected_sk;
+
+       return ERR_PTR(-ECONNREFUSED);
+}
+
+/* bpf_sk_select_reuseport() helper.
+ *
+ * Look up @key in @map and, if the sk stored there belongs to the same
+ * reuseport group the prog is running for, record it in
+ * @reuse_kern->selected_sk; bpf_run_sk_reuseport() returns it if the prog
+ * then returns SK_PASS.  @flags is reserved and must be 0.
+ *
+ * Return: 0 on success, or
+ *   -EINVAL       @flags is non-zero
+ *   -ENOENT       no sk at @key, or the sk is being added/removed
+ *   -EPROTOTYPE   the sk has a different sk_protocol
+ *   -EAFNOSUPPORT the sk has a different sk_family
+ *   -EBADFD       the sk belongs to another reuseport group
+ */
+BPF_CALL_4(sk_select_reuseport, struct sk_reuseport_kern *, reuse_kern,
+          struct bpf_map *, map, void *, key, u64, flags)
+{
+       struct sock_reuseport *reuse;
+       struct sock *selected_sk;
+
+       /* No flag is defined yet.  Reject non-zero flags now so that they
+        * can be given a meaning later without silently changing the
+        * behavior of progs that pass garbage today.
+        */
+       if (unlikely(flags))
+               return -EINVAL;
+
+       selected_sk = map->ops->map_lookup_elem(map, key);
+       if (!selected_sk)
+               return -ENOENT;
+
+       reuse = rcu_dereference(selected_sk->sk_reuseport_cb);
+       if (!reuse) {
+               /* selected_sk is unhashed (e.g. by close()) after the
+                * above map_lookup_elem().  Treat selected_sk as having
+                * already been removed from the map.
+                */
+               return -ENOENT;
+       }
+
+       if (unlikely(reuse->reuseport_id != reuse_kern->reuseport_id)) {
+               struct sock *sk;
+
+               if (unlikely(!reuse_kern->reuseport_id)) {
+                       /* There is a small race between adding the
+                        * sk to the map and setting the
+                        * reuse_kern->reuseport_id.
+                        * Treat it as if the sk has not been added to
+                        * the bpf map yet.
+                        */
+                       return -ENOENT;
+               }
+
+               sk = reuse_kern->sk;
+               if (sk->sk_protocol != selected_sk->sk_protocol)
+                       return -EPROTOTYPE;
+               if (sk->sk_family != selected_sk->sk_family)
+                       return -EAFNOSUPPORT;
+
+               /* Catch all. Likely bound to a different sockaddr. */
+               return -EBADFD;
+       }
+
+       reuse_kern->selected_sk = selected_sk;
+
+       return 0;
+}
+
+static const struct bpf_func_proto sk_select_reuseport_proto = {
+       .func           = sk_select_reuseport,
+       .gpl_only       = false,
+       .ret_type       = RET_INTEGER,
+       .arg1_type      = ARG_PTR_TO_CTX,
+       .arg2_type      = ARG_CONST_MAP_PTR,
+       .arg3_type      = ARG_PTR_TO_MAP_KEY,
+       .arg4_type      = ARG_ANYTHING,
+};
+
+/* bpf_skb_load_bytes() for SK_REUSEPORT progs: the prog ctx is not an skb,
+ * so forward to the common skb implementation with reuse_kern->skb.
+ */
+BPF_CALL_4(sk_reuseport_load_bytes,
+          const struct sk_reuseport_kern *, ctx, u32, off,
+          void *, buf, u32, nbytes)
+{
+       const struct sk_buff *skb = ctx->skb;
+
+       return ____bpf_skb_load_bytes(skb, off, buf, nbytes);
+}
+
+static const struct bpf_func_proto sk_reuseport_load_bytes_proto = {
+       .func           = sk_reuseport_load_bytes,
+       .gpl_only       = false,
+       .ret_type       = RET_INTEGER,
+       .arg1_type      = ARG_PTR_TO_CTX,               /* prog ctx */
+       .arg2_type      = ARG_ANYTHING,                 /* offset */
+       .arg3_type      = ARG_PTR_TO_UNINIT_MEM,        /* destination */
+       .arg4_type      = ARG_CONST_SIZE,               /* byte count */
+};
+
+/* bpf_skb_load_bytes_relative() for SK_REUSEPORT progs: forward to the
+ * common skb implementation with reuse_kern->skb; @hdr is passed through
+ * as the start_header argument.
+ */
+BPF_CALL_5(sk_reuseport_load_bytes_relative,
+          const struct sk_reuseport_kern *, ctx, u32, off,
+          void *, buf, u32, nbytes, u32, hdr)
+{
+       const struct sk_buff *skb = ctx->skb;
+
+       return ____bpf_skb_load_bytes_relative(skb, off, buf, nbytes, hdr);
+}
+
+static const struct bpf_func_proto sk_reuseport_load_bytes_relative_proto = {
+       .func           = sk_reuseport_load_bytes_relative,
+       .gpl_only       = false,
+       .ret_type       = RET_INTEGER,
+       .arg1_type      = ARG_PTR_TO_CTX,               /* prog ctx */
+       .arg2_type      = ARG_ANYTHING,                 /* offset */
+       .arg3_type      = ARG_PTR_TO_UNINIT_MEM,        /* destination */
+       .arg4_type      = ARG_CONST_SIZE,               /* byte count */
+       .arg5_type      = ARG_ANYTHING,                 /* start_header */
+};
+
+/* Helpers callable from SK_REUSEPORT progs: the reuseport selector, the
+ * two skb byte loaders (bound to reuse_kern->skb), and the base helpers.
+ */
+static const struct bpf_func_proto *
+sk_reuseport_func_proto(enum bpf_func_id func_id,
+                       const struct bpf_prog *prog)
+{
+       const struct bpf_func_proto *proto;
+
+       switch (func_id) {
+       case BPF_FUNC_sk_select_reuseport:
+               proto = &sk_select_reuseport_proto;
+               break;
+       case BPF_FUNC_skb_load_bytes:
+               proto = &sk_reuseport_load_bytes_proto;
+               break;
+       case BPF_FUNC_skb_load_bytes_relative:
+               proto = &sk_reuseport_load_bytes_relative_proto;
+               break;
+       default:
+               proto = bpf_base_func_proto(func_id);
+               break;
+       }
+
+       return proto;
+}
+
+/* Check a SK_REUSEPORT prog's access to struct sk_reuseport_md at
+ * @off/@size.  The ctx is read-only; data/data_end must be read as full
+ * 8-byte packet pointers, hash as a full u32, and the remaining u32
+ * fields may be read with narrower loads.
+ */
+static bool
+sk_reuseport_is_valid_access(int off, int size,
+                            enum bpf_access_type type,
+                            const struct bpf_prog *prog,
+                            struct bpf_insn_access_aux *info)
+{
+       const u32 size_default = sizeof(__u32);
+
+       /* Reads only, inside the struct and aligned to the access size. */
+       if (off < 0 || off >= sizeof(struct sk_reuseport_md) ||
+           off % size || type != BPF_READ)
+               return false;
+
+       switch (off) {
+       case offsetof(struct sk_reuseport_md, data):
+               info->reg_type = PTR_TO_PACKET;
+               return size == sizeof(__u64);
+
+       case offsetof(struct sk_reuseport_md, data_end):
+               info->reg_type = PTR_TO_PACKET_END;
+               return size == sizeof(__u64);
+
+       case offsetof(struct sk_reuseport_md, hash):
+               return size == size_default;
+
+       /* Fields that allow narrowing */
+       case offsetof(struct sk_reuseport_md, eth_protocol):
+               /* Reads narrower than skb->protocol are not supported. */
+               if (size < FIELD_SIZEOF(struct sk_buff, protocol))
+                       return false;
+               /* fall through */
+       case offsetof(struct sk_reuseport_md, ip_protocol):
+       case offsetof(struct sk_reuseport_md, bind_inany):
+       case offsetof(struct sk_reuseport_md, len):
+               bpf_ctx_record_field_size(info, size_default);
+               return bpf_ctx_narrow_access_ok(off, size, size_default);
+
+       default:
+               return false;
+       }
+}
+
+/* The macros below are only meant for sk_reuseport_convert_ctx_access():
+ * they use its local variables insn, si and target_size.
+ *
+ * Emit one BPF_LDX_MEM that loads field F of struct sk_reuseport_kern
+ * (the ctx in si->src_reg) into si->dst_reg.
+ */
+#define SK_REUSEPORT_LOAD_FIELD(F) ({                                  \
+       *insn++ = BPF_LDX_MEM(BPF_FIELD_SIZEOF(struct sk_reuseport_kern, F), \
+                             si->dst_reg, si->src_reg,                 \
+                             bpf_target_off(struct sk_reuseport_kern, F, \
+                                            FIELD_SIZEOF(struct sk_reuseport_kern, F), \
+                                            target_size));             \
+       })
+
+/* Load field SKB_FIELD of the sk_buff reached through
+ * sk_reuseport_kern->skb.
+ */
+#define SK_REUSEPORT_LOAD_SKB_FIELD(SKB_FIELD)                         \
+       SOCK_ADDR_LOAD_NESTED_FIELD(struct sk_reuseport_kern,           \
+                                   struct sk_buff,                     \
+                                   skb,                                \
+                                   SKB_FIELD)
+
+/* Load BPF_SIZE bytes at SK_FIELD (+ EXTRA_OFF) of the sock reached
+ * through sk_reuseport_kern->sk.
+ */
+#define SK_REUSEPORT_LOAD_SK_FIELD_SIZE_OFF(SK_FIELD, BPF_SIZE, EXTRA_OFF) \
+       SOCK_ADDR_LOAD_NESTED_FIELD_SIZE_OFF(struct sk_reuseport_kern,  \
+                                            struct sock,               \
+                                            sk,                        \
+                                            SK_FIELD, BPF_SIZE, EXTRA_OFF)
+
+/* Rewrite a prog's read of struct sk_reuseport_md at si->off into BPF
+ * instructions that read the real value from struct sk_reuseport_kern (or
+ * from the skb/sk it points to).  Returns the number of instructions
+ * written to @insn_buf.
+ */
+static u32 sk_reuseport_convert_ctx_access(enum bpf_access_type type,
+                                          const struct bpf_insn *si,
+                                          struct bpf_insn *insn_buf,
+                                          struct bpf_prog *prog,
+                                          u32 *target_size)
+{
+       struct bpf_insn *insn = insn_buf;
+
+       switch (si->off) {
+       /* data, len and eth_protocol are read from reuse_kern->skb. */
+       case offsetof(struct sk_reuseport_md, data):
+               SK_REUSEPORT_LOAD_SKB_FIELD(data);
+               break;
+
+       case offsetof(struct sk_reuseport_md, len):
+               SK_REUSEPORT_LOAD_SKB_FIELD(len);
+               break;
+
+       case offsetof(struct sk_reuseport_md, eth_protocol):
+               SK_REUSEPORT_LOAD_SKB_FIELD(protocol);
+               break;
+
+       case offsetof(struct sk_reuseport_md, ip_protocol):
+               /* Load the 32-bit word at sk->__sk_flags_offset and
+                * extract the protocol with SK_FL_PROTO_MASK/SHIFT; the
+                * BUILD_BUG_ON guarantees the mask is exactly one byte.
+                */
+               BUILD_BUG_ON(hweight_long(SK_FL_PROTO_MASK) != BITS_PER_BYTE);
+               SK_REUSEPORT_LOAD_SK_FIELD_SIZE_OFF(__sk_flags_offset,
+                                                   BPF_W, 0);
+               *insn++ = BPF_ALU32_IMM(BPF_AND, si->dst_reg, SK_FL_PROTO_MASK);
+               *insn++ = BPF_ALU32_IMM(BPF_RSH, si->dst_reg,
+                                       SK_FL_PROTO_SHIFT);
+               /* SK_FL_PROTO_MASK and SK_FL_PROTO_SHIFT are endian
+                * aware.  No further narrowing or masking is needed.
+                */
+               *target_size = 1;
+               break;
+
+       /* The remaining fields are precomputed in struct sk_reuseport_kern
+        * by bpf_init_reuseport_kern().
+        */
+       case offsetof(struct sk_reuseport_md, data_end):
+               SK_REUSEPORT_LOAD_FIELD(data_end);
+               break;
+
+       case offsetof(struct sk_reuseport_md, hash):
+               SK_REUSEPORT_LOAD_FIELD(hash);
+               break;
+
+       case offsetof(struct sk_reuseport_md, bind_inany):
+               SK_REUSEPORT_LOAD_FIELD(bind_inany);
+               break;
+       }
+
+       return insn - insn_buf;
+}
+
+/* Verifier callbacks for BPF_PROG_TYPE_SK_REUSEPORT: ctx access checking,
+ * ctx access rewriting, and the allowed helpers.
+ */
+const struct bpf_verifier_ops sk_reuseport_verifier_ops = {
+       .is_valid_access        = sk_reuseport_is_valid_access,
+       .convert_ctx_access     = sk_reuseport_convert_ctx_access,
+       .get_func_proto         = sk_reuseport_func_proto,
+};
+
+/* No prog_ops callbacks are provided for this prog type. */
+const struct bpf_prog_ops sk_reuseport_prog_ops = {
+};
+#endif /* CONFIG_INET */
diff --git a/net/core/sock_reuseport.c b/net/core/sock_reuseport.c
index 8235f2439816..d260167f5f77 100644
--- a/net/core/sock_reuseport.c
+++ b/net/core/sock_reuseport.c
@@ -51,7 +51,7 @@ static struct sock_reuseport *__reuseport_alloc(unsigned int 
max_socks)
        return reuse;
 }
 
-int reuseport_alloc(struct sock *sk)
+int reuseport_alloc(struct sock *sk, bool bind_inany)
 {
        struct sock_reuseport *reuse;
 
@@ -63,9 +63,17 @@ int reuseport_alloc(struct sock *sk)
        /* Allocation attempts can occur concurrently via the setsockopt path
         * and the bind/hash path.  Nothing to do when we lose the race.
         */
-       if (rcu_dereference_protected(sk->sk_reuseport_cb,
-                                     lockdep_is_held(&reuseport_lock)))
+       reuse = rcu_dereference_protected(sk->sk_reuseport_cb,
+                                         lockdep_is_held(&reuseport_lock));
+       if (reuse) {
+               /* Only set reuse->bind_inany if the bind_inany is true.
+                * Otherwise, it will overwrite the reuse->bind_inany
+                * which was set by the bind/hash path.
+                */
+               if (bind_inany)
+                       reuse->bind_inany = bind_inany;
                goto out;
+       }
 
        reuse = __reuseport_alloc(INIT_SOCKS);
        if (!reuse) {
@@ -75,6 +83,7 @@ int reuseport_alloc(struct sock *sk)
 
        reuse->socks[0] = sk;
        reuse->num_socks = 1;
+       reuse->bind_inany = bind_inany;
        rcu_assign_pointer(sk->sk_reuseport_cb, reuse);
 
 out:
@@ -101,6 +110,7 @@ static struct sock_reuseport *reuseport_grow(struct 
sock_reuseport *reuse)
        more_reuse->num_socks = reuse->num_socks;
        more_reuse->prog = reuse->prog;
        more_reuse->reuseport_id = reuse->reuseport_id;
+       more_reuse->bind_inany = reuse->bind_inany;
 
        memcpy(more_reuse->socks, reuse->socks,
               reuse->num_socks * sizeof(struct sock *));
@@ -136,12 +146,12 @@ static void reuseport_free_rcu(struct rcu_head *head)
  *  @sk2: Socket belonging to the existing reuseport group.
  *  May return ENOMEM and not add socket to group under memory pressure.
  */
-int reuseport_add_sock(struct sock *sk, struct sock *sk2)
+int reuseport_add_sock(struct sock *sk, struct sock *sk2, bool bind_inany)
 {
        struct sock_reuseport *old_reuse, *reuse;
 
        if (!rcu_access_pointer(sk2->sk_reuseport_cb)) {
-               int err = reuseport_alloc(sk2);
+               int err = reuseport_alloc(sk2, bind_inany);
 
                if (err)
                        return err;
diff --git a/net/ipv4/inet_connection_sock.c b/net/ipv4/inet_connection_sock.c
index 33a88e045efd..dfd5009f96ef 100644
--- a/net/ipv4/inet_connection_sock.c
+++ b/net/ipv4/inet_connection_sock.c
@@ -107,6 +107,15 @@ bool inet_rcv_saddr_equal(const struct sock *sk, const 
struct sock *sk2,
 }
 EXPORT_SYMBOL(inet_rcv_saddr_equal);
 
+/* Return true if @sk is bound to the wildcard local address: checks
+ * sk_v6_rcv_saddr for AF_INET6 sockets (when IPv6 is enabled) and
+ * sk_rcv_saddr otherwise.
+ */
+bool inet_rcv_saddr_any(const struct sock *sk)
+{
+#if IS_ENABLED(CONFIG_IPV6)
+       if (sk->sk_family != AF_INET6)
+               return !sk->sk_rcv_saddr;
+
+       return ipv6_addr_any(&sk->sk_v6_rcv_saddr);
+#else
+       return !sk->sk_rcv_saddr;
+#endif
+}
+
 void inet_get_local_port_range(struct net *net, int *low, int *high)
 {
        unsigned int seq;
diff --git a/net/ipv4/inet_hashtables.c b/net/ipv4/inet_hashtables.c
index 3647167c8fa3..370e24463fb7 100644
--- a/net/ipv4/inet_hashtables.c
+++ b/net/ipv4/inet_hashtables.c
@@ -567,10 +567,11 @@ static int inet_reuseport_add_sock(struct sock *sk,
                    inet_csk(sk2)->icsk_bind_hash == tb &&
                    sk2->sk_reuseport && uid_eq(uid, sock_i_uid(sk2)) &&
                    inet_rcv_saddr_equal(sk, sk2, false))
-                       return reuseport_add_sock(sk, sk2);
+                       return reuseport_add_sock(sk, sk2,
+                                                 inet_rcv_saddr_any(sk));
        }
 
-       return reuseport_alloc(sk);
+       return reuseport_alloc(sk, inet_rcv_saddr_any(sk));
 }
 
 int __inet_hash(struct sock *sk, struct sock *osk)
diff --git a/net/ipv4/udp.c b/net/ipv4/udp.c
index 060e841dde40..038dd7909051 100644
--- a/net/ipv4/udp.c
+++ b/net/ipv4/udp.c
@@ -221,11 +221,12 @@ static int udp_reuseport_add_sock(struct sock *sk, struct 
udp_hslot *hslot)
                    (sk2->sk_bound_dev_if == sk->sk_bound_dev_if) &&
                    sk2->sk_reuseport && uid_eq(uid, sock_i_uid(sk2)) &&
                    inet_rcv_saddr_equal(sk, sk2, false)) {
-                       return reuseport_add_sock(sk, sk2);
+                       return reuseport_add_sock(sk, sk2,
+                                                 inet_rcv_saddr_any(sk));
                }
        }
 
-       return reuseport_alloc(sk);
+       return reuseport_alloc(sk, inet_rcv_saddr_any(sk));
 }
 
 /**
-- 
2.17.1

Reply via email to