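Same as for UDP over IPv4 (udp4), split the socket lookup into two phases and let the BPF
inet_lookup program select the receiving socket in between: connected sockets are
looked up first, then the BPF program is run, and only if it does not select a socket
are bound sockets looked up.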

Suggested-by: Marek Majkowski <[email protected]>
Reviewed-by: Lorenz Bauer <[email protected]>
Signed-off-by: Jakub Sitnicki <[email protected]>
---
 net/ipv6/udp.c | 42 ++++++++++++++++++++++++++++++------------
 1 file changed, 30 insertions(+), 12 deletions(-)

diff --git a/net/ipv6/udp.c b/net/ipv6/udp.c
index 16ef2303bd8d..7380cf57e88c 100644
--- a/net/ipv6/udp.c
+++ b/net/ipv6/udp.c
@@ -101,7 +101,7 @@ void udp_v6_rehash(struct sock *sk)
 static int compute_score(struct sock *sk, struct net *net,
                         const struct in6_addr *saddr, __be16 sport,
                         const struct in6_addr *daddr, unsigned short hnum,
-                        int dif, int sdif)
+                        int dif, int sdif, unsigned char state)
 {
        int score;
        struct inet_sock *inet;
@@ -112,6 +112,9 @@ static int compute_score(struct sock *sk, struct net *net,
            sk->sk_family != PF_INET6)
                return -1;
 
+       if (state && sk->sk_state != state)
+               return -1;
+
        if (!ipv6_addr_equal(&sk->sk_v6_rcv_saddr, daddr))
                return -1;
 
@@ -146,7 +149,7 @@ static struct sock *udp6_lib_lookup2(struct net *net,
                const struct in6_addr *saddr, __be16 sport,
                const struct in6_addr *daddr, unsigned int hnum,
                int dif, int sdif, struct udp_hslot *hslot2,
-               struct sk_buff *skb)
+               struct sk_buff *skb, unsigned char state)
 {
        struct sock *sk, *result;
        int score, badness;
@@ -156,7 +159,7 @@ static struct sock *udp6_lib_lookup2(struct net *net,
        badness = -1;
        udp_portaddr_for_each_entry_rcu(sk, &hslot2->head) {
                score = compute_score(sk, net, saddr, sport,
-                                     daddr, hnum, dif, sdif);
+                                     daddr, hnum, dif, sdif, state);
                if (score > badness) {
                        if (sk->sk_reuseport) {
                                hash = udp6_ehashfn(net, daddr, hnum,
@@ -190,19 +193,34 @@ struct sock *__udp6_lib_lookup(struct net *net,
        slot2 = hash2 & udptable->mask;
        hslot2 = &udptable->hash2[slot2];
 
+       /* Lookup connected sockets */
        result = udp6_lib_lookup2(net, saddr, sport,
                                  daddr, hnum, dif, sdif,
-                                 hslot2, skb);
-       if (!result) {
-               hash2 = ipv6_portaddr_hash(net, &in6addr_any, hnum);
-               slot2 = hash2 & udptable->mask;
+                                 hslot2, skb, TCP_ESTABLISHED);
+       if (result)
+               goto done;
 
-               hslot2 = &udptable->hash2[slot2];
+       /* Lookup redirect from BPF */
+       result = inet6_lookup_run_bpf(net, udptable->protocol,
+                                     saddr, sport, daddr, hnum);
+       if (result)
+               goto done;
 
-               result = udp6_lib_lookup2(net, saddr, sport,
-                                         &in6addr_any, hnum, dif, sdif,
-                                         hslot2, skb);
-       }
+       /* Lookup bound sockets */
+       result = udp6_lib_lookup2(net, saddr, sport,
+                                 daddr, hnum, dif, sdif,
+                                 hslot2, skb, 0);
+       if (result)
+               goto done;
+
+       hash2 = ipv6_portaddr_hash(net, &in6addr_any, hnum);
+       slot2 = hash2 & udptable->mask;
+       hslot2 = &udptable->hash2[slot2];
+
+       result = udp6_lib_lookup2(net, saddr, sport,
+                                 &in6addr_any, hnum, dif, sdif,
+                                 hslot2, skb, 0);
+done:
        if (IS_ERR(result))
                return NULL;
        return result;
-- 
2.20.1

Reply via email to