Extend the multicast subscription table to track per-source state,
enabling source-specific (S,G) forwarding for SSM in addition to the
existing group-only (*,G) ASM forwarding.

Data structures:
- Add ovpn_mcast_source to track individual source addresses
- Extend ovpn_mcast_sub with filter_mode (INCLUDE/EXCLUDE) and a
  source list

Subscription API:
- Add ovpn_mcast_sub_update() to create or update a subscription
  with a full source list and filter mode
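- ovpn_mcast_join() becomes a thin wrapper around sub_update()
  (EXCLUDE mode with an empty source list = ASM join); joining a group
  the peer is already subscribed to now resets any existing source
  filter instead of being a no-op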
- Add ovpn_mcast_srcs_update() for incremental source merging:
  ALLOW_NEW adds sources in INCLUDE mode and removes them in EXCLUDE
  mode; BLOCK_OLD does the opposite
- Empty INCLUDE subscriptions are automatically deleted when
  BLOCK_OLD removes the last source

RX path (snooping):
- IGMPv3 and MLDv2 parsers now extract source lists from reports
  and pass them to sub_update()
- All record types are handled: MODE_IS_*, CHANGE_TO_*,
  ALLOW_NEW_SOURCES, BLOCK_OLD_SOURCES

TX path (forwarding):
- Add ovpn_mcast_src_allowed() to evaluate a source against a peer's
  filter mode and source list
- ovpn_peer_list_get_by_mcast_group() now takes a source address
  and only returns peers whose subscription allows the source
- ASM backward compatibility preserved: EXCLUDE with empty source
  list allows all sources

Signed-off-by: Marco Baffo <[email protected]>
---
 drivers/net/ovpn/mcast.c | 334 +++++++++++++++++++++++++++++++++------
 drivers/net/ovpn/mcast.h |  14 +-
 drivers/net/ovpn/peer.c  |   7 +-
 3 files changed, 305 insertions(+), 50 deletions(-)

diff --git a/drivers/net/ovpn/mcast.c b/drivers/net/ovpn/mcast.c
index 1e436a6721bb..74b791ad7489 100644
--- a/drivers/net/ovpn/mcast.c
+++ b/drivers/net/ovpn/mcast.c
@@ -17,9 +17,16 @@ struct ovpn_mcast_group {
        struct list_head subs;
 };
 
+/**
+ * struct ovpn_mcast_source - one source address in a subscription filter
+ * @list: entry in ovpn_mcast_sub.sources
+ * @addr: the source address (IPv4-mapped IPv6 for IPv4 sources)
+ */
+struct ovpn_mcast_source {
+       struct list_head list;
+       struct in6_addr addr;
+};
+
+/**
+ * struct ovpn_mcast_sub - a peer's subscription to a multicast group
+ * @list: entry in ovpn_mcast_group.subs
+ * @peer: the subscribed peer; a reference is held while it is subscribed
+ * @filter_mode: INCLUDE forwards only the listed sources, EXCLUDE forwards
+ *               every source except the listed ones (EXCLUDE with an empty
+ *               list is an any-source/ASM subscription)
+ * @sources: list of struct ovpn_mcast_source
+ */
 struct ovpn_mcast_sub {
        struct list_head list;
        struct ovpn_peer *peer;
+       enum ovpn_mcast_filter_mode filter_mode;
+       struct list_head sources;
 };
 
 static inline u32 ovpn_mcast_hash(const struct in6_addr *group_addr)
@@ -47,10 +54,21 @@ ovpn_mcast_group_find(const struct ovpn_priv *ovpn, const 
struct in6_addr *group
        return NULL;
 }
 
+/*
+ * Unlink and free every ovpn_mcast_source on the list headed by @srcs.
+ * On return @srcs is an empty (but still valid) list head, so it can be
+ * refilled without re-initialisation.
+ */
+static void ovpn_mcast_srcs_del_all(struct list_head *srcs)
+{
+       struct ovpn_mcast_source *entry;
+
+       while (!list_empty(srcs)) {
+               entry = list_first_entry(srcs, struct ovpn_mcast_source, list);
+               list_del(&entry->list);
+               kfree(entry);
+       }
+}
+
 static struct ovpn_peer *ovpn_mcast_sub_del(struct ovpn_mcast_sub *sub)
 {
        struct ovpn_peer *peer = sub->peer;
 
+       ovpn_mcast_srcs_del_all(&sub->sources);
        list_del(&sub->list);
        kfree(sub);
        return peer;
@@ -85,20 +103,138 @@ void ovpn_mcast_cleanup(struct ovpn_priv *ovpn)
        }
 }
 
+/*
+ * For each address in @sources, unlink and free the first matching entry
+ * on @sub's source list. Addresses that are not on the list are ignored.
+ */
+static void ovpn_mcast_srcs_del(struct ovpn_mcast_sub *sub,
+                               const struct in6_addr *sources,
+                               const unsigned int nsrcs)
+{
+       struct ovpn_mcast_source *entry;
+       unsigned int idx;
+
+       for (idx = 0; idx < nsrcs; idx++) {
+               list_for_each_entry(entry, &sub->sources, list) {
+                       if (!ipv6_addr_equal(&entry->addr, &sources[idx]))
+                               continue;
+
+                       /* no _safe() iterator needed: we stop right after
+                        * freeing, so the freed entry is never advanced from
+                        */
+                       list_del(&entry->list);
+                       kfree(entry);
+                       break;
+               }
+       }
+}
+
+/* Return true if @addr is already present on @sub's source list. */
+static bool ovpn_mcast_source_exists(const struct ovpn_mcast_sub *sub,
+                                    const struct in6_addr *addr)
+{
+       struct ovpn_mcast_source *entry;
+
+       list_for_each_entry(entry, &sub->sources, list) {
+               if (ipv6_addr_equal(&entry->addr, addr))
+                       return true;
+       }
+
+       return false;
+}
+
+/*
+ * Append to @sub's source list every address of @sources that is not
+ * already on it. Best effort: if an allocation fails, the remaining
+ * addresses are skipped and the entries added so far are kept.
+ */
+static void ovpn_mcast_srcs_add(struct ovpn_mcast_sub *sub,
+                               const struct in6_addr *sources,
+                               const unsigned int nsrcs)
+{
+       struct ovpn_mcast_source *entry;
+       unsigned int idx;
+
+       for (idx = 0; idx < nsrcs; idx++) {
+               const struct in6_addr *addr = &sources[idx];
+
+               if (ovpn_mcast_source_exists(sub, addr))
+                       continue;
+
+               /* reached with ovpn->lock (a spinlock) held: cannot sleep */
+               entry = kzalloc_obj(*entry, GFP_ATOMIC);
+               if (!entry)
+                       return;
+
+               entry->addr = *addr;
+               list_add_tail(&entry->list, &sub->sources);
+       }
+}
+
+/**
+ * ovpn_mcast_srcs_update - merge an ALLOW/BLOCK source list into a subscription
+ * @sub: the existing subscription to modify
+ * @msg_mode: OVPN_MCAST_INCLUDE for ALLOW_NEW_SOURCES records,
+ *            OVPN_MCAST_EXCLUDE for BLOCK_OLD_SOURCES records
+ * @sources: source addresses carried by the record
+ * @nsrcs: number of entries in @sources
+ *
+ * When @msg_mode matches the subscription's filter mode the sources are
+ * added to its list, otherwise they are removed from it. If a removal
+ * empties an INCLUDE subscription, the subscription itself is freed.
+ * A NULL or empty @sources leaves @sub untouched.
+ *
+ * NOTE(review): this helper only runs for an existing subscription. When
+ * none exists, ovpn_mcast_sub_update() creates one in @msg_mode, so a
+ * BLOCK_OLD record for a group the peer has no state for yields an
+ * EXCLUDE{sources} subscription, i.e. the peer starts receiving every
+ * other source. RFC 3376 sec. 6 / RFC 3810 sec. 7 treat missing state as
+ * INCLUDE{}, on which BLOCK installs no forwarding - confirm intended.
+ *
+ * Return: the peer of a freed subscription, whose reference the caller
+ * must drop with ovpn_peer_put() after releasing ovpn->lock; NULL otherwise.
+ */
+static struct ovpn_peer *
+ovpn_mcast_srcs_update(struct ovpn_mcast_sub *sub,
+                      const enum ovpn_mcast_filter_mode msg_mode,
+                      const struct in6_addr *sources,
+                      const unsigned int nsrcs)
+{
+       if (!sources || !nsrcs)
+               return NULL;
+
+       /* ALLOW_NEW: add in INCLUDE, del in EXCLUDE.
+        * BLOCK_OLD: del in INCLUDE, add in EXCLUDE.
+        */
+       if (sub->filter_mode == msg_mode) {
+               ovpn_mcast_srcs_add(sub, sources, nsrcs);
+       } else {
+               ovpn_mcast_srcs_del(sub, sources, nsrcs);
+               if (sub->filter_mode == OVPN_MCAST_INCLUDE &&
+                   list_empty(&sub->sources))
+                       return ovpn_mcast_sub_del(sub);
+       }
+       return NULL;
+}
+
+/*
+ * Allocate a subscription for @peer with filter mode @mode and an empty
+ * source list, take a reference on @peer and link the subscription into
+ * @group. On success *@subp points to the new subscription and true is
+ * returned. Returns false, leaving *@subp untouched, if the allocation
+ * fails or no reference on @peer can be taken.
+ */
+static bool ovpn_mcast_sub_init(struct ovpn_mcast_sub **subp,
+                               struct ovpn_peer *peer,
+                               const enum ovpn_mcast_filter_mode mode,
+                               struct ovpn_mcast_group *group)
+{
+       struct ovpn_mcast_sub *new_sub;
+
+       new_sub = kzalloc_obj(*new_sub, GFP_ATOMIC);
+       if (unlikely(!new_sub))
+               return false;
+
+       if (!ovpn_peer_hold(peer))
+               goto err_free;
+
+       new_sub->peer = peer;
+       new_sub->filter_mode = mode;
+       INIT_LIST_HEAD(&new_sub->sources);
+       list_add_tail(&new_sub->list, &group->subs);
+
+       *subp = new_sub;
+       return true;
+
+err_free:
+       kfree(new_sub);
+       return false;
+}
+
 /**
- * ovpn_mcast_join - add a peer to a multicast group
+ * ovpn_mcast_sub_update - create, replace, or incrementally update a 
multicast subscription
  * @ovpn: the ovpn instance
- * @peer: the peer joining the group
- * @group_addr: the multicast group address (IPv4-mapped IPv6 for IPv4 groups)
+ * @peer: the peer whose subscription is being updated
+ * @group_addr: the multicast group address
+ * @mode: the filter mode (INCLUDE or EXCLUDE)
+ * @sources: array of source addresses to add or remove
+ * @nsrcs: number of sources in @sources
+ * @incremental_update: if true, merge sources into existing state;
+ *                     if false, replace state entirely
  *
- * Creates the group if it does not exist and adds a subscription for @peer.
- * If the peer is already subscribed, returns success without doing anything.
+ * When @incremental_update is false the subscription is fully replaced with
+ * the given @mode and @sources. An empty source list with INCLUDE mode is
+ * equivalent to leaving the group; with EXCLUDE mode it is an ASM join
+ * (receive all sources).
+ *
+ * When @incremental_update is true the sources are merged: they are added
+ * to the list when @mode matches the current filter mode, or removed when
+ * it differs. ALLOW_NEW maps to INCLUDE; BLOCK_OLD maps to EXCLUDE. If a
+ * BLOCK_OLD operation removes the last source from an INCLUDE subscription,
+ * the subscription is destroyed.
+ *
+ * If no subscription exists for @peer on @group_addr one is created. If the
+ * group does not exist it is created.
+ *
+ * All updates are atomic under @ovpn->lock.
  */
-void ovpn_mcast_join(struct ovpn_priv *ovpn, struct ovpn_peer *peer,
-                    const struct in6_addr *group_addr)
+void ovpn_mcast_sub_update(struct ovpn_priv *ovpn, struct ovpn_peer *peer,
+                          const struct in6_addr *group_addr,
+                          const enum ovpn_mcast_filter_mode mode,
+                          const struct in6_addr *sources,
+                          const unsigned int nsrcs,
+                          const bool incremental_update)
 {
        struct ovpn_mcast_group *group;
        struct ovpn_mcast_sub *sub;
+       struct ovpn_peer *peer_to_put = NULL;
 
        if (!ovpn_mcast_addr_valid(group_addr))
                return;
@@ -117,19 +253,47 @@ void ovpn_mcast_join(struct ovpn_priv *ovpn, struct 
ovpn_peer *peer,
        }
 
        list_for_each_entry(sub, &group->subs, list) {
-               if (sub->peer == peer)
+               if (sub->peer != peer)
+                       continue;
+               if (incremental_update) {
+                       peer_to_put = ovpn_mcast_srcs_update(sub, mode, 
sources, nsrcs);
+                       ovpn_mcast_group_try_del(group);
                        goto end;
+               } else {
+                       sub->filter_mode = mode;
+                       ovpn_mcast_srcs_del_all(&sub->sources);
+                       goto add_sources;
+               }
        }
 
-       sub = kzalloc_obj(*sub, GFP_ATOMIC);
-       if (unlikely(!sub))
+       if (!ovpn_mcast_sub_init(&sub, peer, mode, group)) {
+               ovpn_mcast_group_try_del(group);
                goto end;
-
-       sub->peer = peer;
-       ovpn_peer_hold(peer);
-       list_add_tail(&sub->list, &group->subs);
+       }
+add_sources:
+       if (sources && nsrcs)
+               ovpn_mcast_srcs_add(sub, sources, nsrcs);
 end:
        spin_unlock_bh(&ovpn->lock);
+
+       if (peer_to_put)
+               ovpn_peer_put(peer_to_put);
+}
+
+/**
+ * ovpn_mcast_join - subscribe a peer to every source of a multicast group
+ * @ovpn: the ovpn instance
+ * @peer: the peer joining the group
+ * @group_addr: the multicast group address (IPv4-mapped IPv6 for IPv4 groups)
+ *
+ * Any-source (ASM) join: an EXCLUDE subscription with an empty source list.
+ * Creates the group if it does not exist and a subscription for @peer if it
+ * has none. If @peer is already subscribed, its subscription is reset to
+ * EXCLUDE mode and any previously recorded source filter is discarded, so
+ * @peer receives traffic from all sources again. Invalid group addresses
+ * are ignored.
+ */
+void ovpn_mcast_join(struct ovpn_priv *ovpn, struct ovpn_peer *peer,
+                    const struct in6_addr *group_addr)
+{
+       ovpn_mcast_sub_update(ovpn, peer, group_addr, OVPN_MCAST_EXCLUDE,
+                             NULL, 0, false);
 }
 
 /**
@@ -202,20 +366,36 @@ void ovpn_mcast_leave_all(struct ovpn_peer *peer)
                ovpn_peer_put(peer);
 }
 
+/**
+ * ovpn_mcast_src_allowed - check a packet source against a subscription
+ * @sub: the subscription to evaluate
+ * @src_addr: source address of the multicast packet
+ *
+ * An INCLUDE subscription accepts only the listed sources; an EXCLUDE
+ * subscription accepts every source that is not listed, so EXCLUDE with
+ * an empty list (ASM) accepts everything.
+ *
+ * Return: true if traffic from @src_addr may be forwarded to the peer
+ */
+static bool ovpn_mcast_src_allowed(const struct ovpn_mcast_sub *sub,
+                                  const struct in6_addr *src_addr)
+{
+       if (ovpn_mcast_source_exists(sub, src_addr))
+               return sub->filter_mode == OVPN_MCAST_INCLUDE;
+
+       return sub->filter_mode == OVPN_MCAST_EXCLUDE;
+}
+
 /**
  * ovpn_peer_list_get_by_mcast_group - retrieve peers subscribed to a 
multicast group
  * @ovpn: the ovpn instance to search
  * @group_addr: the multicast group address to look up
  * @list: the lockless list to append matching peers to
  *
- * Searches for the multicast group identified by @group_addr and appends all
- * subscribed peers to @list, acquiring a reference on each one.
+ * @src: the source address to match against per-peer source filters
+ *
+ * Searches for the multicast group identified by @group_addr and appends
+ * subscribed peers whose source filter allows @src to @list, acquiring a
+ * reference on each one.
  *
  * Return: false if no peer was found, true otherwise
  */
 bool ovpn_peer_list_get_by_mcast_group(struct ovpn_priv *ovpn,
                                       const struct in6_addr *group_addr,
-                                      struct llist_head *list)
+                                      struct llist_head *list,
+                                      const struct in6_addr *src)
 {
        struct ovpn_mcast_group *group;
        struct ovpn_mcast_sub *sub;
@@ -225,7 +405,8 @@ bool ovpn_peer_list_get_by_mcast_group(struct ovpn_priv 
*ovpn,
        group = ovpn_mcast_group_find(ovpn, group_addr);
        if (group) {
                list_for_each_entry(sub, &group->subs, list) {
-                       if (ovpn_peer_hold(sub->peer))
+                       if (ovpn_mcast_src_allowed(sub, src) &&
+                           ovpn_peer_hold(sub->peer))
                                llist_add(&sub->peer->mcast_entry, list);
                }
        }
@@ -305,18 +486,47 @@ static bool ovpn_mcast_snoop_mldv2(struct ovpn_peer 
*peer, struct sk_buff *skb,
                /* recompute grec after potential head reallocation */
                grec = (struct mld2_grec *)(skb_network_header(skb) + offset - 
rec_len);
 
-               /* In MLDv2 ASM, EXCLUDE mode with an empty source list means
-                * "exclude nothing, receive everything" -> JOIN.
-                * INCLUDE mode with an empty source list means
-                * "include nothing, receive nothing" -> LEAVE.
-                * See RFC 3810, section 4.
-                */
-               if (nsrcs == 0 &&
-                   (grec->grec_type == MLD2_CHANGE_TO_INCLUDE ||
-                    grec->grec_type == MLD2_MODE_IS_INCLUDE)) {
-                       ovpn_mcast_leave(peer->ovpn, peer, &grec->grec_mca);
-               } else {
-                       ovpn_mcast_join(peer->ovpn, peer, &grec->grec_mca);
+               switch (grec->grec_type) {
+               case MLD2_MODE_IS_INCLUDE:
+               case MLD2_CHANGE_TO_INCLUDE:
+                       if (nsrcs == 0)
+                               ovpn_mcast_leave(peer->ovpn, peer,
+                                                &grec->grec_mca);
+                       else
+                               ovpn_mcast_sub_update(peer->ovpn, peer,
+                                                     &grec->grec_mca,
+                                                     OVPN_MCAST_INCLUDE,
+                                                     grec->grec_src, nsrcs,
+                                                     false);
+                       break;
+               case MLD2_MODE_IS_EXCLUDE:
+               case MLD2_CHANGE_TO_EXCLUDE:
+                       if (nsrcs == 0)
+                               ovpn_mcast_join(peer->ovpn, peer,
+                                               &grec->grec_mca);
+                       else
+                               ovpn_mcast_sub_update(peer->ovpn, peer,
+                                                     &grec->grec_mca,
+                                                     OVPN_MCAST_EXCLUDE,
+                                                     grec->grec_src, nsrcs,
+                                                     false);
+                       break;
+               case MLD2_ALLOW_NEW_SOURCES:
+                       if (nsrcs)
+                               ovpn_mcast_sub_update(peer->ovpn, peer,
+                                                     &grec->grec_mca,
+                                                     OVPN_MCAST_INCLUDE,
+                                                     grec->grec_src, nsrcs,
+                                                     true);
+                       break;
+               case MLD2_BLOCK_OLD_SOURCES:
+                       if (nsrcs)
+                               ovpn_mcast_sub_update(peer->ovpn, peer,
+                                                     &grec->grec_mca,
+                                                     OVPN_MCAST_EXCLUDE,
+                                                     grec->grec_src, nsrcs,
+                                                     true);
+                       break;
                }
        }
 
@@ -381,9 +591,9 @@ static bool ovpn_mcast_snoop_igmpv3(struct ovpn_peer *peer, 
struct sk_buff *skb,
                                    unsigned int offset, const int ngrec)
 {
        struct igmpv3_grec *grec;
-       struct in6_addr addr6;
+       struct in6_addr addr6, *srcs = NULL;
        int i;
-       unsigned int rec_len;
+       unsigned int j, rec_len;
        __u16 nsrcs;
 
        for (i = 0; i < ngrec; i++) {
@@ -403,21 +613,53 @@ static bool ovpn_mcast_snoop_igmpv3(struct ovpn_peer 
*peer, struct sk_buff *skb,
                /* recompute grec after potential head reallocation */
                grec = (struct igmpv3_grec *)(skb_network_header(skb) + offset 
- rec_len);
 
-               /* In IGMPv3 ASM, EXCLUDE mode with an empty source list means
-                * "exclude nothing, receive everything" -> JOIN.
-                * INCLUDE mode with an empty source list means
-                * "include nothing, receive nothing" -> LEAVE.
-                * See RFC 3376, section 3.
-                */
-               if (nsrcs == 0 &&
-                   (grec->grec_type == IGMPV3_CHANGE_TO_INCLUDE ||
-                    grec->grec_type == IGMPV3_MODE_IS_INCLUDE)) {
-                       ipv6_addr_set_v4mapped(grec->grec_mca, &addr6);
-                       ovpn_mcast_leave(peer->ovpn, peer, &addr6);
-               } else {
-                       ipv6_addr_set_v4mapped(grec->grec_mca, &addr6);
-                       ovpn_mcast_join(peer->ovpn, peer, &addr6);
+               ipv6_addr_set_v4mapped(grec->grec_mca, &addr6);
+
+               if (nsrcs > 0) {
+                       srcs = kcalloc(nsrcs, sizeof(*srcs), GFP_ATOMIC);
+                       if (!srcs)
+                               return false;
+
+                       for (j = 0; j < nsrcs; j++)
+                               ipv6_addr_set_v4mapped(grec->grec_src[j],
+                                                      &srcs[j]);
                }
+
+               switch (grec->grec_type) {
+               case IGMPV3_MODE_IS_INCLUDE:
+               case IGMPV3_CHANGE_TO_INCLUDE:
+                       if (nsrcs == 0)
+                               ovpn_mcast_leave(peer->ovpn, peer, &addr6);
+                       else
+                               ovpn_mcast_sub_update(peer->ovpn, peer, &addr6,
+                                                     OVPN_MCAST_INCLUDE, srcs,
+                                                     nsrcs, false);
+                       break;
+               case IGMPV3_MODE_IS_EXCLUDE:
+               case IGMPV3_CHANGE_TO_EXCLUDE:
+                       if (nsrcs == 0)
+                               ovpn_mcast_join(peer->ovpn, peer, &addr6);
+                       else
+                               ovpn_mcast_sub_update(peer->ovpn, peer, &addr6,
+                                                     OVPN_MCAST_EXCLUDE, srcs,
+                                                     nsrcs, false);
+                       break;
+               case IGMPV3_ALLOW_NEW_SOURCES:
+                       if (nsrcs)
+                               ovpn_mcast_sub_update(peer->ovpn, peer, &addr6,
+                                                     OVPN_MCAST_INCLUDE, srcs,
+                                                     nsrcs, true);
+                       break;
+               case IGMPV3_BLOCK_OLD_SOURCES:
+                       if (nsrcs)
+                               ovpn_mcast_sub_update(peer->ovpn, peer, &addr6,
+                                                     OVPN_MCAST_EXCLUDE, srcs,
+                                                     nsrcs, true);
+                       break;
+               }
+
+               kfree(srcs);
+               srcs = NULL;
        }
 
        return true;
diff --git a/drivers/net/ovpn/mcast.h b/drivers/net/ovpn/mcast.h
index 9e06e893a355..b41812534d58 100644
--- a/drivers/net/ovpn/mcast.h
+++ b/drivers/net/ovpn/mcast.h
@@ -13,15 +13,27 @@ struct in6_addr;
 struct llist_head;
 struct sk_buff;
 
+/**
+ * enum ovpn_mcast_filter_mode - source filter mode of a multicast subscription
+ * @OVPN_MCAST_EXCLUDE: forward every source except the listed ones; with an
+ *                      empty source list this is an any-source (ASM) join
+ * @OVPN_MCAST_INCLUDE: forward only the listed sources
+ */
+enum ovpn_mcast_filter_mode {
+       OVPN_MCAST_EXCLUDE,
+       OVPN_MCAST_INCLUDE,
+};
+
 void ovpn_mcast_cleanup(struct ovpn_priv *ovpn);
 void ovpn_mcast_join(struct ovpn_priv *ovpn, struct ovpn_peer *peer,
                     const struct in6_addr *group_addr);
 void ovpn_mcast_leave(struct ovpn_priv *ovpn, struct ovpn_peer *peer,
                      const struct in6_addr *group_addr);
+void ovpn_mcast_sub_update(struct ovpn_priv *ovpn, struct ovpn_peer *peer,
+                          const struct in6_addr *group_addr,
+                          const enum ovpn_mcast_filter_mode mode,
+                          const struct in6_addr *sources,
+                          const unsigned int nsrcs,
+                          const bool incremental_update);
 void ovpn_mcast_leave_all(struct ovpn_peer *peer);
 bool ovpn_peer_list_get_by_mcast_group(struct ovpn_priv *ovpn,
                                       const struct in6_addr *group_addr,
-                                      struct llist_head *list);
+                                      struct llist_head *list,
+                                      const struct in6_addr *src);
 bool ovpn_mcast_is_control(struct sk_buff *skb);
 bool ovpn_mcast_snoop_skb(struct ovpn_peer *peer, struct sk_buff *skb);
 
diff --git a/drivers/net/ovpn/peer.c b/drivers/net/ovpn/peer.c
index a9728a157210..3fc69c3cecc0 100644
--- a/drivers/net/ovpn/peer.c
+++ b/drivers/net/ovpn/peer.c
@@ -751,7 +751,7 @@ void ovpn_peer_list_get_by_dst(struct ovpn_priv *ovpn, 
struct sk_buff *skb,
 {
        struct ovpn_peer *peer = NULL;
        unsigned int addr_type;
-       struct in6_addr addr6;
+       struct in6_addr addr6, src;
        __be32 addr4;
 
        /* in P2P mode, no matter the destination, packets are always sent to
@@ -779,7 +779,8 @@ void ovpn_peer_list_get_by_dst(struct ovpn_priv *ovpn, 
struct sk_buff *skb,
                addr_type = inet_dev_addr_type(dev_net(ovpn->dev), ovpn->dev, 
addr4);
                if (addr_type == RTN_MULTICAST) {
                        ipv6_addr_set_v4mapped(addr4, &addr6);
-                       if (!ovpn_peer_list_get_by_mcast_group(ovpn, &addr6, 
list) &&
+                       ipv6_addr_set_v4mapped(ip_hdr(skb)->saddr, &src);
+                       if (!ovpn_peer_list_get_by_mcast_group(ovpn, &addr6, 
list, &src) &&
                            ovpn_mcast_is_control(skb)) {
                                ovpn_peer_list_get_all(ovpn, list);
                        }
@@ -797,7 +798,7 @@ void ovpn_peer_list_get_by_dst(struct ovpn_priv *ovpn, 
struct sk_buff *skb,
 
                rcu_read_unlock();
                if (ipv6_addr_is_multicast(&addr6) &&
-                   !ovpn_peer_list_get_by_mcast_group(ovpn, &addr6, list) &&
+                   !ovpn_peer_list_get_by_mcast_group(ovpn, &addr6, list, 
&ipv6_hdr(skb)->saddr) &&
                    ovpn_mcast_is_control(skb)) {
                        ovpn_peer_list_get_all(ovpn, list);
                }
-- 
2.43.0



_______________________________________________
Openvpn-devel mailing list
[email protected]
https://lists.sourceforge.net/lists/listinfo/openvpn-devel

Reply via email to