Right now UDP outgoing traffic is kmem-auto-charged into the cg
kmem. Incoming traffic is not, but it has a TCP-like memory
scheduler (but simpler, with just one limit). So here is
per-cgroup limiting of UDP read buffers, done in the same way
as it is done for TCP.

Signed-off-by: Pavel Emelyanov <xe...@parallels.com>

---
 include/net/udp.h            |   1 +
 include/net/udp_memcontrol.h |  13 +++
 mm/memcontrol.c              |  11 +++
 net/ipv4/Makefile            |   1 +
 net/ipv4/udp.c               |  18 ++++
 net/ipv4/udp_memcontrol.c    | 221 +++++++++++++++++++++++++++++++++++++++++++
 net/ipv6/udp.c               |   5 +
 7 files changed, 270 insertions(+)
 create mode 100644 include/net/udp_memcontrol.h
 create mode 100644 net/ipv4/udp_memcontrol.c

diff --git a/include/net/udp.h b/include/net/udp.h
index 74c10ec..2ad7d90 100644
--- a/include/net/udp.h
+++ b/include/net/udp.h
@@ -273,4 +273,5 @@ extern void udp_encap_enable(void);
 #if IS_ENABLED(CONFIG_IPV6)
 extern void udpv6_encap_enable(void);
 #endif
+extern int udp_init_sock(struct sock *sk);
 #endif /* _UDP_H */
diff --git a/include/net/udp_memcontrol.h b/include/net/udp_memcontrol.h
new file mode 100644
index 0000000..34a2cba
--- /dev/null
+++ b/include/net/udp_memcontrol.h
@@ -0,0 +1,13 @@
+#ifndef _UDP_MEMCG_H
+#define _UDP_MEMCG_H
+
+struct udp_memcontrol {
+       struct cg_proto cg_proto;
+       struct res_counter udp_memory_allocated;
+       long udp_prot_mem[3];
+};
+
+struct cg_proto *udp_proto_cgroup(struct mem_cgroup *memcg);
+int udp_init_cgroup(struct mem_cgroup *memcg, struct cgroup_subsys *ss);
+void udp_destroy_cgroup(struct mem_cgroup *memcg);
+#endif /* _UDP_MEMCG_H */
diff --git a/mm/memcontrol.c b/mm/memcontrol.c
index d38868c..4d0a756 100644
--- a/mm/memcontrol.c
+++ b/mm/memcontrol.c
@@ -57,6 +57,7 @@
 #include <net/sock.h>
 #include <net/ip.h>
 #include <net/tcp_memcontrol.h>
+#include <net/udp_memcontrol.h>
 #include "slab.h"
 
 #include <asm/uaccess.h>
@@ -331,6 +332,7 @@ struct mem_cgroup {
        atomic_t        dead_count;
 #if defined(CONFIG_MEMCG_KMEM) && defined(CONFIG_INET)
        struct tcp_memcontrol tcp_mem;
+       struct udp_memcontrol udp_mem;
 #endif
 #if defined(CONFIG_MEMCG_KMEM)
         /* Index in the kmem_cache->memcg_params.memcg_caches array */
@@ -551,6 +553,15 @@ struct cg_proto *tcp_proto_cgroup(struct mem_cgroup *memcg)
 }
 EXPORT_SYMBOL(tcp_proto_cgroup);
 
+struct cg_proto *udp_proto_cgroup(struct mem_cgroup *memcg)
+{
+       if (!memcg || mem_cgroup_is_root(memcg))
+               return NULL;
+
+       return &memcg->udp_mem.cg_proto;
+}
+EXPORT_SYMBOL(udp_proto_cgroup);
+
 static void disarm_sock_keys(struct mem_cgroup *memcg)
 {
        if (!memcg_proto_activated(&memcg->tcp_mem.cg_proto))
diff --git a/net/ipv4/Makefile b/net/ipv4/Makefile
index f8c49ce..4b8119d 100644
--- a/net/ipv4/Makefile
+++ b/net/ipv4/Makefile
@@ -52,6 +52,7 @@ obj-$(CONFIG_TCP_CONG_LP) += tcp_lp.o
 obj-$(CONFIG_TCP_CONG_YEAH) += tcp_yeah.o
 obj-$(CONFIG_TCP_CONG_ILLINOIS) += tcp_illinois.o
 obj-$(CONFIG_MEMCG_KMEM) += tcp_memcontrol.o
+obj-$(CONFIG_MEMCG_KMEM) += udp_memcontrol.o
 obj-$(CONFIG_NETLABEL) += cipso_ipv4.o
 
 obj-$(CONFIG_XFRM) += xfrm4_policy.o xfrm4_state.o xfrm4_input.o \
diff --git a/net/ipv4/udp.c b/net/ipv4/udp.c
index 687731b..b0352b0 100644
--- a/net/ipv4/udp.c
+++ b/net/ipv4/udp.c
@@ -111,6 +111,7 @@
 #include <trace/events/skb.h>
 #include <net/busy_poll.h>
 #include "udp_impl.h"
+#include <net/udp_memcontrol.h>
 
 struct udp_table udp_table __read_mostly;
 EXPORT_SYMBOL(udp_table);
@@ -1786,6 +1787,7 @@ void udp_destroy_sock(struct sock *sk)
                if (encap_destroy)
                        encap_destroy(sk);
        }
+       sock_release_memcg(sk);
 }
 
 /*
@@ -1984,6 +1986,16 @@ unsigned int udp_poll(struct file *file, struct socket 
*sock, poll_table *wait)
 }
 EXPORT_SYMBOL(udp_poll);
 
+int udp_init_sock(struct sock *sk)
+{
+       local_bh_disable();
+       sock_update_memcg(sk);
+       local_bh_enable();
+
+       return 0;
+}
+EXPORT_SYMBOL(udp_init_sock);
+
 struct proto udp_prot = {
        .name              = "UDP",
        .owner             = THIS_MODULE,
@@ -1991,6 +2003,7 @@ struct proto udp_prot = {
        .connect           = ip4_datagram_connect,
        .disconnect        = udp_disconnect,
        .ioctl             = udp_ioctl,
+       .init              = udp_init_sock,
        .destroy           = udp_destroy_sock,
        .setsockopt        = udp_setsockopt,
        .getsockopt        = udp_getsockopt,
@@ -2015,6 +2028,11 @@ struct proto udp_prot = {
        .compat_getsockopt = compat_udp_getsockopt,
 #endif
        .clear_sk          = sk_prot_clear_portaddr_nulls,
+#ifdef CONFIG_MEMCG_KMEM
+       .init_cgroup            = udp_init_cgroup,
+       .destroy_cgroup         = udp_destroy_cgroup,
+       .proto_cgroup           = udp_proto_cgroup,
+#endif
 };
 EXPORT_SYMBOL(udp_prot);
 
diff --git a/net/ipv4/udp_memcontrol.c b/net/ipv4/udp_memcontrol.c
new file mode 100644
index 0000000..d9f7977
--- /dev/null
+++ b/net/ipv4/udp_memcontrol.c
@@ -0,0 +1,221 @@
+#include <net/udp.h>
+#include <net/udp_memcontrol.h>
+#include <net/sock.h>
+#include <net/ip.h>
+#include <linux/nsproxy.h>
+#include <linux/memcontrol.h>
+#include <linux/module.h>
+
+/*
+ * The below code is copied from tcp_memcontrol.c with
+ * s/tcp/udp/g and knowledge that udp doesn't need mem
+ * pressure state and sockets_allocated counter.
+ */
+
+static inline struct udp_memcontrol *udp_from_cgproto(struct cg_proto 
*cg_proto)
+{
+       return container_of(cg_proto, struct udp_memcontrol, cg_proto);
+}
+
+int udp_init_cgroup(struct mem_cgroup *memcg, struct cgroup_subsys *ss)
+{
+       /*
+        * The root cgroup does not use res_counters, but rather,
+        * rely on the data already collected by the network
+        * subsystem
+        */
+       struct res_counter *res_parent = NULL;
+       struct cg_proto *cg_proto, *parent_cg;
+       struct udp_memcontrol *udp;
+       struct mem_cgroup *parent = parent_mem_cgroup(memcg);
+
+       cg_proto = udp_prot.proto_cgroup(memcg);
+       if (!cg_proto)
+               return 0;
+
+       udp = udp_from_cgproto(cg_proto);
+
+       udp->udp_prot_mem[0] = sysctl_udp_mem[0];
+       udp->udp_prot_mem[1] = sysctl_udp_mem[1];
+       udp->udp_prot_mem[2] = sysctl_udp_mem[2];
+
+       parent_cg = udp_prot.proto_cgroup(parent);
+       if (parent_cg)
+               res_parent = parent_cg->memory_allocated;
+
+       res_counter_init(&udp->udp_memory_allocated, res_parent);
+
+       cg_proto->sysctl_mem = udp->udp_prot_mem;
+       cg_proto->memory_allocated = &udp->udp_memory_allocated;
+       cg_proto->memcg = memcg;
+
+       return 0;
+}
+
+void udp_destroy_cgroup(struct mem_cgroup *memcg)
+{
+}
+
+static int udp_update_limit(struct mem_cgroup *memcg, u64 val)
+{
+       struct udp_memcontrol *udp;
+       struct cg_proto *cg_proto;
+       u64 old_lim;
+       int i;
+       int ret;
+
+       cg_proto = udp_prot.proto_cgroup(memcg);
+       if (!cg_proto)
+               return -EINVAL;
+
+       if (val > RESOURCE_MAX)
+               val = RESOURCE_MAX;
+
+       udp = udp_from_cgproto(cg_proto);
+
+       old_lim = res_counter_read_u64(&udp->udp_memory_allocated, RES_LIMIT);
+       ret = res_counter_set_limit(&udp->udp_memory_allocated, val);
+       if (ret)
+               return ret;
+
+       for (i = 0; i < 3; i++)
+               udp->udp_prot_mem[i] = min_t(long, val >> PAGE_SHIFT, 
sysctl_udp_mem[i]);
+
+       if (val == RESOURCE_MAX)
+               clear_bit(MEMCG_SOCK_ACTIVE, &cg_proto->flags);
+       else if (val != RESOURCE_MAX) {
+               if (!test_and_set_bit(MEMCG_SOCK_ACTIVATED, &cg_proto->flags))
+                       static_key_slow_inc(&memcg_socket_limit_enabled);
+               set_bit(MEMCG_SOCK_ACTIVE, &cg_proto->flags);
+       }
+
+       return 0;
+}
+
+static int udp_cgroup_write(struct cgroup *cont, struct cftype *cft,
+                           const char *buffer)
+{
+       struct mem_cgroup *memcg = mem_cgroup_from_cont(cont);
+       unsigned long long val;
+       int ret = 0;
+
+       switch (cft->private) {
+       case RES_LIMIT:
+               /* see memcontrol.c */
+               ret = res_counter_memparse_write_strategy(buffer, &val);
+               if (ret)
+                       break;
+               ret = udp_update_limit(memcg, val);
+               break;
+       default:
+               ret = -EINVAL;
+               break;
+       }
+       return ret;
+}
+
+static u64 udp_read_stat(struct mem_cgroup *memcg, int type, u64 default_val)
+{
+       struct udp_memcontrol *udp;
+       struct cg_proto *cg_proto;
+
+       cg_proto = udp_prot.proto_cgroup(memcg);
+       if (!cg_proto)
+               return default_val;
+
+       udp = udp_from_cgproto(cg_proto);
+       return res_counter_read_u64(&udp->udp_memory_allocated, type);
+}
+
+static u64 udp_read_usage(struct mem_cgroup *memcg)
+{
+       struct udp_memcontrol *udp;
+       struct cg_proto *cg_proto;
+
+       cg_proto = udp_prot.proto_cgroup(memcg);
+       if (!cg_proto)
+               return atomic_long_read(&udp_memory_allocated) << PAGE_SHIFT;
+
+       udp = udp_from_cgproto(cg_proto);
+       return res_counter_read_u64(&udp->udp_memory_allocated, RES_USAGE);
+}
+
+static u64 udp_cgroup_read(struct cgroup *cont, struct cftype *cft)
+{
+       struct mem_cgroup *memcg = mem_cgroup_from_cont(cont);
+       u64 val;
+
+       switch (cft->private) {
+       case RES_LIMIT:
+               val = udp_read_stat(memcg, RES_LIMIT, RESOURCE_MAX);
+               break;
+       case RES_USAGE:
+               val = udp_read_usage(memcg);
+               break;
+       case RES_FAILCNT:
+       case RES_MAX_USAGE:
+               val = udp_read_stat(memcg, cft->private, 0);
+               break;
+       default:
+               BUG();
+       }
+       return val;
+}
+
+static int udp_cgroup_reset(struct cgroup *cont, unsigned int event)
+{
+       struct mem_cgroup *memcg;
+       struct udp_memcontrol *udp;
+       struct cg_proto *cg_proto;
+
+       memcg = mem_cgroup_from_cont(cont);
+       cg_proto = udp_prot.proto_cgroup(memcg);
+       if (!cg_proto)
+               return 0;
+       udp = udp_from_cgproto(cg_proto);
+
+       switch (event) {
+       case RES_MAX_USAGE:
+               res_counter_reset_max(&udp->udp_memory_allocated);
+               break;
+       case RES_FAILCNT:
+               res_counter_reset_failcnt(&udp->udp_memory_allocated);
+               break;
+       }
+
+       return 0;
+}
+
+static struct cftype udp_files[] = {
+       {
+               .name = "kmem.udp.limit_in_bytes",
+               .write_string = udp_cgroup_write,
+               .read_u64 = udp_cgroup_read,
+               .private = RES_LIMIT,
+       },
+       {
+               .name = "kmem.udp.usage_in_bytes",
+               .read_u64 = udp_cgroup_read,
+               .private = RES_USAGE,
+       },
+       {
+               .name = "kmem.udp.failcnt",
+               .private = RES_FAILCNT,
+               .trigger = udp_cgroup_reset,
+               .read_u64 = udp_cgroup_read,
+       },
+       {
+               .name = "kmem.udp.max_usage_in_bytes",
+               .private = RES_MAX_USAGE,
+               .trigger = udp_cgroup_reset,
+               .read_u64 = udp_cgroup_read,
+       },
+       { }     /* terminate */
+};
+
+static int __init udp_memcontrol_init(void)
+{
+       WARN_ON(cgroup_add_cftypes(&mem_cgroup_subsys, udp_files));
+       return 0;
+}
+__initcall(udp_memcontrol_init);
diff --git a/net/ipv6/udp.c b/net/ipv6/udp.c
index 83b0a99..17d7df7 100644
--- a/net/ipv6/udp.c
+++ b/net/ipv6/udp.c
@@ -52,6 +52,7 @@
 #include <linux/seq_file.h>
 #include <trace/events/skb.h>
 #include "udp_impl.h"
+#include <net/udp_memcontrol.h>
 
 int ipv6_rcv_saddr_equal(const struct sock *sk, const struct sock *sk2)
 {
@@ -1420,6 +1421,7 @@ struct proto udpv6_prot = {
        .connect           = ip6_datagram_connect,
        .disconnect        = udp_disconnect,
        .ioctl             = udp_ioctl,
+       .init              = udp_init_sock,
        .destroy           = udpv6_destroy_sock,
        .setsockopt        = udpv6_setsockopt,
        .getsockopt        = udpv6_getsockopt,
@@ -1442,6 +1444,9 @@ struct proto udpv6_prot = {
        .compat_getsockopt = compat_udpv6_getsockopt,
 #endif
        .clear_sk          = udp_v6_clear_sk,
+#ifdef CONFIG_MEMCG_KMEM
+       .proto_cgroup           = udp_proto_cgroup,
+#endif
 };
 
 static struct inet_protosw udpv6_protosw = {
-- 
1.8.3.1


_______________________________________________
Devel mailing list
Devel@openvz.org
https://lists.openvz.org/mailman/listinfo/devel

Reply via email to