On 5/31/2023 4:05 AM, GONG, Ruiqi wrote:
> As the security infrastructure has taken over the management of multiple
> *_security blobs that are accessed by multiple security modules, and
> sk->sk_security shares the same situation, move its management out of
> individual security modules and into the security infrastructure as
> well. The infrastructure does the memory allocation, and each relevant
> module uses its own share.

Do you have a reason to make this change? The LSM infrastructure
manages other security blobs to enable multiple concurrently active
LSMs to use the blob. If only one LSM on a system can use the
socket blob, there's no reason to move the management.

>
> Signed-off-by: GONG, Ruiqi <[email protected]>
> ---
>  include/linux/lsm_hooks.h         |  1 +
>  security/apparmor/include/net.h   |  2 +-
>  security/apparmor/lsm.c           | 20 +-------
>  security/security.c               | 35 ++++++++++++-
>  security/selinux/hooks.c          | 81 ++++++++++++++-----------------
>  security/selinux/include/objsec.h |  4 ++
>  security/selinux/netlabel.c       | 22 ++++-----
>  security/smack/smack.h            |  5 ++
>  security/smack/smack_lsm.c        | 65 +++++++++++--------------
>  security/smack/smack_netfilter.c  |  4 +-
>  10 files changed, 125 insertions(+), 114 deletions(-)
>
> diff --git a/include/linux/lsm_hooks.h b/include/linux/lsm_hooks.h
> index ab2b2fafa4a4..67b6e87ca6ec 100644
> --- a/include/linux/lsm_hooks.h
> +++ b/include/linux/lsm_hooks.h
> @@ -62,6 +62,7 @@ struct lsm_blob_sizes {
>       int     lbs_superblock;
>       int     lbs_ipc;
>       int     lbs_msg_msg;
> +     int     lbs_sock;
>       int     lbs_task;
>  };
>  
> diff --git a/security/apparmor/include/net.h b/security/apparmor/include/net.h
> index 6fa440b5daed..9eb159c09578 100644
> --- a/security/apparmor/include/net.h
> +++ b/security/apparmor/include/net.h
> @@ -51,7 +51,7 @@ struct aa_sk_ctx {
>       struct aa_label *peer;
>  };
>  
> -#define SK_CTX(X) ((X)->sk_security)
> +#define SK_CTX(X) ((X)->sk_security + apparmor_blob_sizes.lbs_sock)
>  #define SOCK_ctx(X) SOCK_INODE(X)->i_security
>  #define DEFINE_AUDIT_NET(NAME, OP, SK, F, T, P)                              
>   \
>       struct lsm_network_audit NAME ## _net = { .sk = (SK),             \
> diff --git a/security/apparmor/lsm.c b/security/apparmor/lsm.c
> index f431251ffb91..3dd849a6d7a1 100644
> --- a/security/apparmor/lsm.c
> +++ b/security/apparmor/lsm.c
> @@ -818,22 +818,6 @@ static int apparmor_task_kill(struct task_struct 
> *target, struct kernel_siginfo
>       return error;
>  }
>  
> -/**
> - * apparmor_sk_alloc_security - allocate and attach the sk_security field
> - */
> -static int apparmor_sk_alloc_security(struct sock *sk, int family, gfp_t 
> flags)
> -{
> -     struct aa_sk_ctx *ctx;
> -
> -     ctx = kzalloc(sizeof(*ctx), flags);
> -     if (!ctx)
> -             return -ENOMEM;
> -
> -     SK_CTX(sk) = ctx;
> -
> -     return 0;
> -}
> -
>  /**
>   * apparmor_sk_free_security - free the sk_security field
>   */
> @@ -841,10 +825,8 @@ static void apparmor_sk_free_security(struct sock *sk)
>  {
>       struct aa_sk_ctx *ctx = SK_CTX(sk);
>  
> -     SK_CTX(sk) = NULL;
>       aa_put_label(ctx->label);
>       aa_put_label(ctx->peer);
> -     kfree(ctx);
>  }
>  
>  /**
> @@ -1212,6 +1194,7 @@ static int apparmor_inet_conn_request(const struct sock 
> *sk, struct sk_buff *skb
>  struct lsm_blob_sizes apparmor_blob_sizes __ro_after_init = {
>       .lbs_cred = sizeof(struct aa_label *),
>       .lbs_file = sizeof(struct aa_file_ctx),
> +     .lbs_sock = sizeof(struct aa_sk_ctx),
>       .lbs_task = sizeof(struct aa_task_ctx),
>  };
>  
> @@ -1250,7 +1233,6 @@ static struct security_hook_list apparmor_hooks[] 
> __ro_after_init = {
>       LSM_HOOK_INIT(getprocattr, apparmor_getprocattr),
>       LSM_HOOK_INIT(setprocattr, apparmor_setprocattr),
>  
> -     LSM_HOOK_INIT(sk_alloc_security, apparmor_sk_alloc_security),
>       LSM_HOOK_INIT(sk_free_security, apparmor_sk_free_security),
>       LSM_HOOK_INIT(sk_clone_security, apparmor_sk_clone_security),
>  
> diff --git a/security/security.c b/security/security.c
> index b720424ca37d..e71f4717cde5 100644
> --- a/security/security.c
> +++ b/security/security.c
> @@ -30,6 +30,7 @@
>  #include <linux/string.h>
>  #include <linux/msg.h>
>  #include <net/flow.h>
> +#include <net/sock.h>
>  
>  #define MAX_LSM_EVM_XATTR    2
>  
> @@ -210,6 +211,7 @@ static void __init lsm_set_blob_sizes(struct 
> lsm_blob_sizes *needed)
>       lsm_set_blob_size(&needed->lbs_inode, &blob_sizes.lbs_inode);
>       lsm_set_blob_size(&needed->lbs_ipc, &blob_sizes.lbs_ipc);
>       lsm_set_blob_size(&needed->lbs_msg_msg, &blob_sizes.lbs_msg_msg);
> +     lsm_set_blob_size(&needed->lbs_sock, &blob_sizes.lbs_sock);
>       lsm_set_blob_size(&needed->lbs_superblock, &blob_sizes.lbs_superblock);
>       lsm_set_blob_size(&needed->lbs_task, &blob_sizes.lbs_task);
>  }
> @@ -376,6 +378,7 @@ static void __init ordered_lsm_init(void)
>       init_debug("inode blob size      = %d\n", blob_sizes.lbs_inode);
>       init_debug("ipc blob size        = %d\n", blob_sizes.lbs_ipc);
>       init_debug("msg_msg blob size    = %d\n", blob_sizes.lbs_msg_msg);
> +     init_debug("sock blob size       = %d\n", blob_sizes.lbs_sock);
>       init_debug("superblock blob size = %d\n", blob_sizes.lbs_superblock);
>       init_debug("task blob size       = %d\n", blob_sizes.lbs_task);
>  
> @@ -733,6 +736,27 @@ static int lsm_superblock_alloc(struct super_block *sb)
>       return 0;
>  }
>  
> +/**
> + * lsm_sock_alloc - allocate a composite socket blob
> + * @sk: the socket that needs a blob
> + *
> + * Allocate the socket blob for all the modules
> + *
> + * Returns 0, or -ENOMEM if memory can't be allocated.
> + */
> +static int lsm_sock_alloc(struct sock *sk)
> +{
> +     if (blob_sizes.lbs_sock == 0) {
> +             sk->sk_security = NULL;
> +             return 0;
> +     }
> +
> +     sk->sk_security = kzalloc(blob_sizes.lbs_sock, GFP_KERNEL);
> +     if (sk->sk_security == NULL)
> +             return -ENOMEM;
> +     return 0;
> +}
> +
>  /*
>   * The default value of the LSM hook is defined in linux/lsm_hook_defs.h and
>   * can be accessed with:
> @@ -4369,7 +4393,14 @@ EXPORT_SYMBOL(security_socket_getpeersec_dgram);
>   */
>  int security_sk_alloc(struct sock *sk, int family, gfp_t priority)
>  {
> -     return call_int_hook(sk_alloc_security, 0, sk, family, priority);
> +     int rc = lsm_sock_alloc(sk);
> +
> +     if (unlikely(rc))
> +             return rc;
> +     rc = call_int_hook(sk_alloc_security, 0, sk, family, priority);
> +     if (unlikely(rc))
> +             security_sk_free(sk);
> +     return rc;
>  }
>  
>  /**
> @@ -4381,6 +4412,8 @@ int security_sk_alloc(struct sock *sk, int family, 
> gfp_t priority)
>  void security_sk_free(struct sock *sk)
>  {
>       call_void_hook(sk_free_security, sk);
> +     kfree(sk->sk_security);
> +     sk->sk_security = NULL;
>  }
>  
>  /**
> diff --git a/security/selinux/hooks.c b/security/selinux/hooks.c
> index d06e350fedee..f8397f05dc90 100644
> --- a/security/selinux/hooks.c
> +++ b/security/selinux/hooks.c
> @@ -4497,7 +4497,7 @@ static int socket_sockcreate_sid(const struct 
> task_security_struct *tsec,
>  
>  static int sock_has_perm(struct sock *sk, u32 perms)
>  {
> -     struct sk_security_struct *sksec = sk->sk_security;
> +     struct sk_security_struct *sksec = selinux_sock(sk);
>       struct common_audit_data ad;
>       struct lsm_network_audit net = {0,};
>  
> @@ -4552,7 +4552,7 @@ static int selinux_socket_post_create(struct socket 
> *sock, int family,
>       isec->initialized = LABEL_INITIALIZED;
>  
>       if (sock->sk) {
> -             sksec = sock->sk->sk_security;
> +             sksec = selinux_sock(sock->sk);
>               sksec->sclass = sclass;
>               sksec->sid = sid;
>               /* Allows detection of the first association on this socket */
> @@ -4568,8 +4568,8 @@ static int selinux_socket_post_create(struct socket 
> *sock, int family,
>  static int selinux_socket_socketpair(struct socket *socka,
>                                    struct socket *sockb)
>  {
> -     struct sk_security_struct *sksec_a = socka->sk->sk_security;
> -     struct sk_security_struct *sksec_b = sockb->sk->sk_security;
> +     struct sk_security_struct *sksec_a = selinux_sock(socka->sk);
> +     struct sk_security_struct *sksec_b = selinux_sock(sockb->sk);
>  
>       sksec_a->peer_sid = sksec_b->sid;
>       sksec_b->peer_sid = sksec_a->sid;
> @@ -4584,7 +4584,7 @@ static int selinux_socket_socketpair(struct socket 
> *socka,
>  static int selinux_socket_bind(struct socket *sock, struct sockaddr 
> *address, int addrlen)
>  {
>       struct sock *sk = sock->sk;
> -     struct sk_security_struct *sksec = sk->sk_security;
> +     struct sk_security_struct *sksec = selinux_sock(sk);
>       u16 family;
>       int err;
>  
> @@ -4717,7 +4717,7 @@ static int selinux_socket_connect_helper(struct socket 
> *sock,
>                                        struct sockaddr *address, int addrlen)
>  {
>       struct sock *sk = sock->sk;
> -     struct sk_security_struct *sksec = sk->sk_security;
> +     struct sk_security_struct *sksec = selinux_sock(sk);
>       int err;
>  
>       err = sock_has_perm(sk, SOCKET__CONNECT);
> @@ -4895,9 +4895,9 @@ static int selinux_socket_unix_stream_connect(struct 
> sock *sock,
>                                             struct sock *other,
>                                             struct sock *newsk)
>  {
> -     struct sk_security_struct *sksec_sock = sock->sk_security;
> -     struct sk_security_struct *sksec_other = other->sk_security;
> -     struct sk_security_struct *sksec_new = newsk->sk_security;
> +     struct sk_security_struct *sksec_sock = selinux_sock(sock);
> +     struct sk_security_struct *sksec_other = selinux_sock(other);
> +     struct sk_security_struct *sksec_new = selinux_sock(newsk);
>       struct common_audit_data ad;
>       struct lsm_network_audit net = {0,};
>       int err;
> @@ -4928,8 +4928,8 @@ static int selinux_socket_unix_stream_connect(struct 
> sock *sock,
>  static int selinux_socket_unix_may_send(struct socket *sock,
>                                       struct socket *other)
>  {
> -     struct sk_security_struct *ssec = sock->sk->sk_security;
> -     struct sk_security_struct *osec = other->sk->sk_security;
> +     struct sk_security_struct *ssec = selinux_sock(sock->sk);
> +     struct sk_security_struct *osec = selinux_sock(other->sk);
>       struct common_audit_data ad;
>       struct lsm_network_audit net = {0,};
>  
> @@ -4968,7 +4968,7 @@ static int selinux_sock_rcv_skb_compat(struct sock *sk, 
> struct sk_buff *skb,
>                                      u16 family)
>  {
>       int err = 0;
> -     struct sk_security_struct *sksec = sk->sk_security;
> +     struct sk_security_struct *sksec = selinux_sock(sk);
>       u32 sk_sid = sksec->sid;
>       struct common_audit_data ad;
>       struct lsm_network_audit net = {0,};
> @@ -5000,7 +5000,7 @@ static int selinux_sock_rcv_skb_compat(struct sock *sk, 
> struct sk_buff *skb,
>  static int selinux_socket_sock_rcv_skb(struct sock *sk, struct sk_buff *skb)
>  {
>       int err;
> -     struct sk_security_struct *sksec = sk->sk_security;
> +     struct sk_security_struct *sksec = selinux_sock(sk);
>       u16 family = sk->sk_family;
>       u32 sk_sid = sksec->sid;
>       struct common_audit_data ad;
> @@ -5073,7 +5073,7 @@ static int selinux_socket_getpeersec_stream(struct 
> socket *sock,
>       int err = 0;
>       char *scontext = NULL;
>       u32 scontext_len;
> -     struct sk_security_struct *sksec = sock->sk->sk_security;
> +     struct sk_security_struct *sksec = selinux_sock(sock->sk);
>       u32 peer_sid = SECSID_NULL;
>  
>       if (sksec->sclass == SECCLASS_UNIX_STREAM_SOCKET ||
> @@ -5131,34 +5131,27 @@ static int selinux_socket_getpeersec_dgram(struct 
> socket *sock, struct sk_buff *
>  
>  static int selinux_sk_alloc_security(struct sock *sk, int family, gfp_t 
> priority)
>  {
> -     struct sk_security_struct *sksec;
> -
> -     sksec = kzalloc(sizeof(*sksec), priority);
> -     if (!sksec)
> -             return -ENOMEM;
> +     struct sk_security_struct *sksec = selinux_sock(sk);
>  
>       sksec->peer_sid = SECINITSID_UNLABELED;
>       sksec->sid = SECINITSID_UNLABELED;
>       sksec->sclass = SECCLASS_SOCKET;
>       selinux_netlbl_sk_security_reset(sksec);
> -     sk->sk_security = sksec;
>  
>       return 0;
>  }
>  
>  static void selinux_sk_free_security(struct sock *sk)
>  {
> -     struct sk_security_struct *sksec = sk->sk_security;
> +     struct sk_security_struct *sksec = selinux_sock(sk);
>  
> -     sk->sk_security = NULL;
>       selinux_netlbl_sk_security_free(sksec);
> -     kfree(sksec);
>  }
>  
>  static void selinux_sk_clone_security(const struct sock *sk, struct sock 
> *newsk)
>  {
> -     struct sk_security_struct *sksec = sk->sk_security;
> -     struct sk_security_struct *newsksec = newsk->sk_security;
> +     struct sk_security_struct *sksec = selinux_sock(sk);
> +     struct sk_security_struct *newsksec = selinux_sock(newsk);
>  
>       newsksec->sid = sksec->sid;
>       newsksec->peer_sid = sksec->peer_sid;
> @@ -5172,7 +5165,7 @@ static void selinux_sk_getsecid(struct sock *sk, u32 
> *secid)
>       if (!sk)
>               *secid = SECINITSID_ANY_SOCKET;
>       else {
> -             struct sk_security_struct *sksec = sk->sk_security;
> +             struct sk_security_struct *sksec = selinux_sock(sk);
>  
>               *secid = sksec->sid;
>       }
> @@ -5182,7 +5175,7 @@ static void selinux_sock_graft(struct sock *sk, struct 
> socket *parent)
>  {
>       struct inode_security_struct *isec =
>               inode_security_novalidate(SOCK_INODE(parent));
> -     struct sk_security_struct *sksec = sk->sk_security;
> +     struct sk_security_struct *sksec = selinux_sock(sk);
>  
>       if (sk->sk_family == PF_INET || sk->sk_family == PF_INET6 ||
>           sk->sk_family == PF_UNIX)
> @@ -5199,7 +5192,7 @@ static int selinux_sctp_process_new_assoc(struct 
> sctp_association *asoc,
>  {
>       struct sock *sk = asoc->base.sk;
>       u16 family = sk->sk_family;
> -     struct sk_security_struct *sksec = sk->sk_security;
> +     struct sk_security_struct *sksec = selinux_sock(sk);
>       struct common_audit_data ad;
>       struct lsm_network_audit net = {0,};
>       int err;
> @@ -5256,7 +5249,7 @@ static int selinux_sctp_process_new_assoc(struct 
> sctp_association *asoc,
>  static int selinux_sctp_assoc_request(struct sctp_association *asoc,
>                                     struct sk_buff *skb)
>  {
> -     struct sk_security_struct *sksec = asoc->base.sk->sk_security;
> +     struct sk_security_struct *sksec = selinux_sock(asoc->base.sk);
>       u32 conn_sid;
>       int err;
>  
> @@ -5289,7 +5282,7 @@ static int selinux_sctp_assoc_request(struct 
> sctp_association *asoc,
>  static int selinux_sctp_assoc_established(struct sctp_association *asoc,
>                                         struct sk_buff *skb)
>  {
> -     struct sk_security_struct *sksec = asoc->base.sk->sk_security;
> +     struct sk_security_struct *sksec = selinux_sock(asoc->base.sk);
>  
>       if (!selinux_policycap_extsockclass())
>               return 0;
> @@ -5388,8 +5381,8 @@ static int selinux_sctp_bind_connect(struct sock *sk, 
> int optname,
>  static void selinux_sctp_sk_clone(struct sctp_association *asoc, struct sock 
> *sk,
>                                 struct sock *newsk)
>  {
> -     struct sk_security_struct *sksec = sk->sk_security;
> -     struct sk_security_struct *newsksec = newsk->sk_security;
> +     struct sk_security_struct *sksec = selinux_sock(sk);
> +     struct sk_security_struct *newsksec = selinux_sock(newsk);
>  
>       /* If policy does not support SECCLASS_SCTP_SOCKET then call
>        * the non-sctp clone version.
> @@ -5405,8 +5398,8 @@ static void selinux_sctp_sk_clone(struct 
> sctp_association *asoc, struct sock *sk
>  
>  static int selinux_mptcp_add_subflow(struct sock *sk, struct sock *ssk)
>  {
> -     struct sk_security_struct *ssksec = ssk->sk_security;
> -     struct sk_security_struct *sksec = sk->sk_security;
> +     struct sk_security_struct *ssksec = selinux_sock(ssk);
> +     struct sk_security_struct *sksec = selinux_sock(sk);
>  
>       ssksec->sclass = sksec->sclass;
>       ssksec->sid = sksec->sid;
> @@ -5421,7 +5414,7 @@ static int selinux_mptcp_add_subflow(struct sock *sk, 
> struct sock *ssk)
>  static int selinux_inet_conn_request(const struct sock *sk, struct sk_buff 
> *skb,
>                                    struct request_sock *req)
>  {
> -     struct sk_security_struct *sksec = sk->sk_security;
> +     struct sk_security_struct *sksec = selinux_sock(sk);
>       int err;
>       u16 family = req->rsk_ops->family;
>       u32 connsid;
> @@ -5442,7 +5435,7 @@ static int selinux_inet_conn_request(const struct sock 
> *sk, struct sk_buff *skb,
>  static void selinux_inet_csk_clone(struct sock *newsk,
>                                  const struct request_sock *req)
>  {
> -     struct sk_security_struct *newsksec = newsk->sk_security;
> +     struct sk_security_struct *newsksec = selinux_sock(newsk);
>  
>       newsksec->sid = req->secid;
>       newsksec->peer_sid = req->peer_secid;
> @@ -5459,7 +5452,7 @@ static void selinux_inet_csk_clone(struct sock *newsk,
>  static void selinux_inet_conn_established(struct sock *sk, struct sk_buff 
> *skb)
>  {
>       u16 family = sk->sk_family;
> -     struct sk_security_struct *sksec = sk->sk_security;
> +     struct sk_security_struct *sksec = selinux_sock(sk);
>  
>       /* handle mapped IPv4 packets arriving via IPv6 sockets */
>       if (family == PF_INET6 && skb->protocol == htons(ETH_P_IP))
> @@ -5540,7 +5533,7 @@ static int selinux_tun_dev_attach_queue(void *security)
>  static int selinux_tun_dev_attach(struct sock *sk, void *security)
>  {
>       struct tun_security_struct *tunsec = security;
> -     struct sk_security_struct *sksec = sk->sk_security;
> +     struct sk_security_struct *sksec = selinux_sock(sk);
>  
>       /* we don't currently perform any NetLabel based labeling here and it
>        * isn't clear that we would want to do so anyway; while we could apply
> @@ -5666,7 +5659,7 @@ static unsigned int selinux_ip_output(void *priv, 
> struct sk_buff *skb,
>                       return NF_ACCEPT;
>  
>               /* standard practice, label using the parent socket */
> -             sksec = sk->sk_security;
> +             sksec = selinux_sock(sk);
>               sid = sksec->sid;
>       } else
>               sid = SECINITSID_KERNEL;
> @@ -5689,7 +5682,7 @@ static unsigned int selinux_ip_postroute_compat(struct 
> sk_buff *skb,
>       sk = skb_to_full_sk(skb);
>       if (sk == NULL)
>               return NF_ACCEPT;
> -     sksec = sk->sk_security;
> +     sksec = selinux_sock(sk);
>  
>       ad.type = LSM_AUDIT_DATA_NET;
>       ad.u.net = &net;
> @@ -5779,9 +5772,8 @@ static unsigned int selinux_ip_postroute(void *priv,
>                * selinux_inet_conn_request().  See also selinux_ip_output()
>                * for similar problems. */
>               u32 skb_sid;
> -             struct sk_security_struct *sksec;
> +             struct sk_security_struct *sksec = selinux_sock(sk);
>  
> -             sksec = sk->sk_security;
>               if (selinux_skb_peerlbl_sid(skb, family, &skb_sid))
>                       return NF_DROP;
>               /* At this point, if the returned skb peerlbl is SECSID_NULL
> @@ -5810,7 +5802,7 @@ static unsigned int selinux_ip_postroute(void *priv,
>       } else {
>               /* Locally generated packet, fetch the security label from the
>                * associated socket. */
> -             struct sk_security_struct *sksec = sk->sk_security;
> +             struct sk_security_struct *sksec = selinux_sock(sk);
>               peer_sid = sksec->sid;
>               secmark_perm = PACKET__SEND;
>       }
> @@ -5856,7 +5848,7 @@ static int selinux_netlink_send(struct sock *sk, struct 
> sk_buff *skb)
>       unsigned int data_len = skb->len;
>       unsigned char *data = skb->data;
>       struct nlmsghdr *nlh;
> -     struct sk_security_struct *sksec = sk->sk_security;
> +     struct sk_security_struct *sksec = selinux_sock(sk);
>       u16 sclass = sksec->sclass;
>       u32 perm;
>  
> @@ -6814,6 +6806,7 @@ struct lsm_blob_sizes selinux_blob_sizes 
> __ro_after_init = {
>       .lbs_inode = sizeof(struct inode_security_struct),
>       .lbs_ipc = sizeof(struct ipc_security_struct),
>       .lbs_msg_msg = sizeof(struct msg_security_struct),
> +     .lbs_sock = sizeof(struct sk_security_struct),
>       .lbs_superblock = sizeof(struct superblock_security_struct),
>  };
>  
> diff --git a/security/selinux/include/objsec.h 
> b/security/selinux/include/objsec.h
> index 2953132408bf..49221f441c68 100644
> --- a/security/selinux/include/objsec.h
> +++ b/security/selinux/include/objsec.h
> @@ -194,4 +194,8 @@ static inline struct superblock_security_struct 
> *selinux_superblock(
>       return superblock->s_security + selinux_blob_sizes.lbs_superblock;
>  }
>  
> +static inline struct sk_security_struct *selinux_sock(const struct sock *sk)
> +{
> +     return sk->sk_security + selinux_blob_sizes.lbs_sock;
> +}
>  #endif /* _SELINUX_OBJSEC_H_ */
> diff --git a/security/selinux/netlabel.c b/security/selinux/netlabel.c
> index 528f5186e912..9755561aa466 100644
> --- a/security/selinux/netlabel.c
> +++ b/security/selinux/netlabel.c
> @@ -68,7 +68,7 @@ static int selinux_netlbl_sidlookup_cached(struct sk_buff 
> *skb,
>  static struct netlbl_lsm_secattr *selinux_netlbl_sock_genattr(struct sock 
> *sk)
>  {
>       int rc;
> -     struct sk_security_struct *sksec = sk->sk_security;
> +     struct sk_security_struct *sksec = selinux_sock(sk);
>       struct netlbl_lsm_secattr *secattr;
>  
>       if (sksec->nlbl_secattr != NULL)
> @@ -100,7 +100,7 @@ static struct netlbl_lsm_secattr 
> *selinux_netlbl_sock_getattr(
>                                                       const struct sock *sk,
>                                                       u32 sid)
>  {
> -     struct sk_security_struct *sksec = sk->sk_security;
> +     struct sk_security_struct *sksec = selinux_sock(sk);
>       struct netlbl_lsm_secattr *secattr = sksec->nlbl_secattr;
>  
>       if (secattr == NULL)
> @@ -239,7 +239,7 @@ int selinux_netlbl_skbuff_setsid(struct sk_buff *skb,
>        * being labeled by it's parent socket, if it is just exit */
>       sk = skb_to_full_sk(skb);
>       if (sk != NULL) {
> -             struct sk_security_struct *sksec = sk->sk_security;
> +             struct sk_security_struct *sksec = selinux_sock(sk);
>  
>               if (sksec->nlbl_state != NLBL_REQSKB)
>                       return 0;
> @@ -276,7 +276,7 @@ int selinux_netlbl_sctp_assoc_request(struct 
> sctp_association *asoc,
>  {
>       int rc;
>       struct netlbl_lsm_secattr secattr;
> -     struct sk_security_struct *sksec = asoc->base.sk->sk_security;
> +     struct sk_security_struct *sksec = selinux_sock(asoc->base.sk);
>       struct sockaddr_in addr4;
>       struct sockaddr_in6 addr6;
>  
> @@ -355,7 +355,7 @@ int selinux_netlbl_inet_conn_request(struct request_sock 
> *req, u16 family)
>   */
>  void selinux_netlbl_inet_csk_clone(struct sock *sk, u16 family)
>  {
> -     struct sk_security_struct *sksec = sk->sk_security;
> +     struct sk_security_struct *sksec = selinux_sock(sk);
>  
>       if (family == PF_INET)
>               sksec->nlbl_state = NLBL_LABELED;
> @@ -373,8 +373,8 @@ void selinux_netlbl_inet_csk_clone(struct sock *sk, u16 
> family)
>   */
>  void selinux_netlbl_sctp_sk_clone(struct sock *sk, struct sock *newsk)
>  {
> -     struct sk_security_struct *sksec = sk->sk_security;
> -     struct sk_security_struct *newsksec = newsk->sk_security;
> +     struct sk_security_struct *sksec = selinux_sock(sk);
> +     struct sk_security_struct *newsksec = selinux_sock(newsk);
>  
>       newsksec->nlbl_state = sksec->nlbl_state;
>  }
> @@ -392,7 +392,7 @@ void selinux_netlbl_sctp_sk_clone(struct sock *sk, struct 
> sock *newsk)
>  int selinux_netlbl_socket_post_create(struct sock *sk, u16 family)
>  {
>       int rc;
> -     struct sk_security_struct *sksec = sk->sk_security;
> +     struct sk_security_struct *sksec = selinux_sock(sk);
>       struct netlbl_lsm_secattr *secattr;
>  
>       if (family != PF_INET && family != PF_INET6)
> @@ -506,7 +506,7 @@ int selinux_netlbl_socket_setsockopt(struct socket *sock,
>  {
>       int rc = 0;
>       struct sock *sk = sock->sk;
> -     struct sk_security_struct *sksec = sk->sk_security;
> +     struct sk_security_struct *sksec = selinux_sock(sk);
>       struct netlbl_lsm_secattr secattr;
>  
>       if (selinux_netlbl_option(level, optname) &&
> @@ -544,7 +544,7 @@ static int selinux_netlbl_socket_connect_helper(struct 
> sock *sk,
>                                               struct sockaddr *addr)
>  {
>       int rc;
> -     struct sk_security_struct *sksec = sk->sk_security;
> +     struct sk_security_struct *sksec = selinux_sock(sk);
>       struct netlbl_lsm_secattr *secattr;
>  
>       /* connected sockets are allowed to disconnect when the address family
> @@ -583,7 +583,7 @@ static int selinux_netlbl_socket_connect_helper(struct 
> sock *sk,
>  int selinux_netlbl_socket_connect_locked(struct sock *sk,
>                                        struct sockaddr *addr)
>  {
> -     struct sk_security_struct *sksec = sk->sk_security;
> +     struct sk_security_struct *sksec = selinux_sock(sk);
>  
>       if (sksec->nlbl_state != NLBL_REQSKB &&
>           sksec->nlbl_state != NLBL_CONNLABELED)
> diff --git a/security/smack/smack.h b/security/smack/smack.h
> index aa15ff56ed6e..2d0163076eca 100644
> --- a/security/smack/smack.h
> +++ b/security/smack/smack.h
> @@ -355,6 +355,11 @@ static inline struct superblock_smack *smack_superblock(
>       return superblock->s_security + smack_blob_sizes.lbs_superblock;
>  }
>  
> +static inline struct socket_smack *smack_sock(const struct sock *sk)
> +{
> +     return sk->sk_security + smack_blob_sizes.lbs_sock;
> +}
> +
>  /*
>   * Is the directory transmuting?
>   */
> diff --git a/security/smack/smack_lsm.c b/security/smack/smack_lsm.c
> index 6e270cf3fd30..ab026ff79504 100644
> --- a/security/smack/smack_lsm.c
> +++ b/security/smack/smack_lsm.c
> @@ -1502,7 +1502,7 @@ static int smack_inode_getsecurity(struct mnt_idmap 
> *idmap,
>               if (sock == NULL || sock->sk == NULL)
>                       return -EOPNOTSUPP;
>  
> -             ssp = sock->sk->sk_security;
> +             ssp = smack_sock(sock->sk);
>  
>               if (strcmp(name, XATTR_SMACK_IPIN) == 0)
>                       isp = ssp->smk_in;
> @@ -1890,7 +1890,7 @@ static int smack_file_receive(struct file *file)
>  
>       if (inode->i_sb->s_magic == SOCKFS_MAGIC) {
>               sock = SOCKET_I(inode);
> -             ssp = sock->sk->sk_security;
> +             ssp = smack_sock(sock->sk);
>               tsp = smack_cred(current_cred());
>               /*
>                * If the receiving process can't write to the
> @@ -2310,11 +2310,7 @@ static void smack_task_to_inode(struct task_struct *p, 
> struct inode *inode)
>  static int smack_sk_alloc_security(struct sock *sk, int family, gfp_t 
> gfp_flags)
>  {
>       struct smack_known *skp = smk_of_current();
> -     struct socket_smack *ssp;
> -
> -     ssp = kzalloc(sizeof(struct socket_smack), gfp_flags);
> -     if (ssp == NULL)
> -             return -ENOMEM;
> +     struct socket_smack *ssp = smack_sock(sk);
>  
>       /*
>        * Sockets created by kernel threads receive web label.
> @@ -2328,8 +2324,6 @@ static int smack_sk_alloc_security(struct sock *sk, int 
> family, gfp_t gfp_flags)
>       }
>       ssp->smk_packet = NULL;
>  
> -     sk->sk_security = ssp;
> -
>       return 0;
>  }
>  
> @@ -2355,7 +2349,6 @@ static void smack_sk_free_security(struct sock *sk)
>               rcu_read_unlock();
>       }
>  #endif
> -     kfree(sk->sk_security);
>  }
>  
>  /**
> @@ -2367,8 +2360,8 @@ static void smack_sk_free_security(struct sock *sk)
>   */
>  static void smack_sk_clone_security(const struct sock *sk, struct sock 
> *newsk)
>  {
> -     struct socket_smack *ssp_old = sk->sk_security;
> -     struct socket_smack *ssp_new = newsk->sk_security;
> +     struct socket_smack *ssp_old = smack_sock(sk);
> +     struct socket_smack *ssp_new = smack_sock(newsk);
>  
>       *ssp_new = *ssp_old;
>  }
> @@ -2484,7 +2477,7 @@ static struct smack_known *smack_ipv6host_label(struct 
> sockaddr_in6 *sip)
>   */
>  static int smack_netlbl_add(struct sock *sk)
>  {
> -     struct socket_smack *ssp = sk->sk_security;
> +     struct socket_smack *ssp = smack_sock(sk);
>       struct smack_known *skp = ssp->smk_out;
>       int rc;
>  
> @@ -2516,7 +2509,7 @@ static int smack_netlbl_add(struct sock *sk)
>   */
>  static void smack_netlbl_delete(struct sock *sk)
>  {
> -     struct socket_smack *ssp = sk->sk_security;
> +     struct socket_smack *ssp = smack_sock(sk);
>  
>       /*
>        * Take the label off the socket if one is set.
> @@ -2548,7 +2541,7 @@ static int smk_ipv4_check(struct sock *sk, struct 
> sockaddr_in *sap)
>       struct smack_known *skp;
>       int rc = 0;
>       struct smack_known *hkp;
> -     struct socket_smack *ssp = sk->sk_security;
> +     struct socket_smack *ssp = smack_sock(sk);
>       struct smk_audit_info ad;
>  
>       rcu_read_lock();
> @@ -2621,7 +2614,7 @@ static void smk_ipv6_port_label(struct socket *sock, 
> struct sockaddr *address)
>  {
>       struct sock *sk = sock->sk;
>       struct sockaddr_in6 *addr6;
> -     struct socket_smack *ssp = sock->sk->sk_security;
> +     struct socket_smack *ssp = smack_sock(sock->sk);
>       struct smk_port_label *spp;
>       unsigned short port = 0;
>  
> @@ -2709,7 +2702,7 @@ static int smk_ipv6_port_check(struct sock *sk, struct 
> sockaddr_in6 *address,
>                               int act)
>  {
>       struct smk_port_label *spp;
> -     struct socket_smack *ssp = sk->sk_security;
> +     struct socket_smack *ssp = smack_sock(sk);
>       struct smack_known *skp = NULL;
>       unsigned short port;
>       struct smack_known *object;
> @@ -2803,7 +2796,7 @@ static int smack_inode_setsecurity(struct inode *inode, 
> const char *name,
>       if (sock == NULL || sock->sk == NULL)
>               return -EOPNOTSUPP;
>  
> -     ssp = sock->sk->sk_security;
> +     ssp = smack_sock(sock->sk);
>  
>       if (strcmp(name, XATTR_SMACK_IPIN) == 0)
>               ssp->smk_in = skp;
> @@ -2851,7 +2844,7 @@ static int smack_socket_post_create(struct socket 
> *sock, int family,
>        * Sockets created by kernel threads receive web label.
>        */
>       if (unlikely(current->flags & PF_KTHREAD)) {
> -             ssp = sock->sk->sk_security;
> +             ssp = smack_sock(sock->sk);
>               ssp->smk_in = &smack_known_web;
>               ssp->smk_out = &smack_known_web;
>       }
> @@ -2876,8 +2869,8 @@ static int smack_socket_post_create(struct socket 
> *sock, int family,
>  static int smack_socket_socketpair(struct socket *socka,
>                                  struct socket *sockb)
>  {
> -     struct socket_smack *asp = socka->sk->sk_security;
> -     struct socket_smack *bsp = sockb->sk->sk_security;
> +     struct socket_smack *asp = smack_sock(socka->sk);
> +     struct socket_smack *bsp = smack_sock(sockb->sk);
>  
>       asp->smk_packet = bsp->smk_out;
>       bsp->smk_packet = asp->smk_out;
> @@ -2940,7 +2933,7 @@ static int smack_socket_connect(struct socket *sock, 
> struct sockaddr *sap,
>               if (__is_defined(SMACK_IPV6_SECMARK_LABELING))
>                       rsp = smack_ipv6host_label(sip);
>               if (rsp != NULL) {
> -                     struct socket_smack *ssp = sock->sk->sk_security;
> +                     struct socket_smack *ssp = smack_sock(sock->sk);
>  
>                       rc = smk_ipv6_check(ssp->smk_out, rsp, sip,
>                                           SMK_CONNECTING);
> @@ -3671,9 +3664,9 @@ static int smack_unix_stream_connect(struct sock *sock,
>  {
>       struct smack_known *skp;
>       struct smack_known *okp;
> -     struct socket_smack *ssp = sock->sk_security;
> -     struct socket_smack *osp = other->sk_security;
> -     struct socket_smack *nsp = newsk->sk_security;
> +     struct socket_smack *ssp = smack_sock(sock);
> +     struct socket_smack *osp = smack_sock(other);
> +     struct socket_smack *nsp = smack_sock(newsk);
>       struct smk_audit_info ad;
>       int rc = 0;
>  #ifdef CONFIG_AUDIT
> @@ -3719,8 +3712,8 @@ static int smack_unix_stream_connect(struct sock *sock,
>   */
>  static int smack_unix_may_send(struct socket *sock, struct socket *other)
>  {
> -     struct socket_smack *ssp = sock->sk->sk_security;
> -     struct socket_smack *osp = other->sk->sk_security;
> +     struct socket_smack *ssp = smack_sock(sock->sk);
> +     struct socket_smack *osp = smack_sock(other->sk);
>       struct smk_audit_info ad;
>       int rc;
>  
> @@ -3757,7 +3750,7 @@ static int smack_socket_sendmsg(struct socket *sock, struct msghdr *msg,
>       struct sockaddr_in6 *sap = (struct sockaddr_in6 *) msg->msg_name;
>  #endif
>  #ifdef SMACK_IPV6_SECMARK_LABELING
> -     struct socket_smack *ssp = sock->sk->sk_security;
> +     struct socket_smack *ssp = smack_sock(sock->sk);
>       struct smack_known *rsp;
>  #endif
>       int rc = 0;
> @@ -3969,7 +3962,7 @@ static struct smack_known *smack_from_netlbl(const struct sock *sk, u16 family,
>       netlbl_secattr_init(&secattr);
>  
>       if (sk)
> -             ssp = sk->sk_security;
> +             ssp = smack_sock(sk);
>  
>       if (netlbl_skbuff_getattr(skb, family, &secattr) == 0) {
>               skp = smack_from_secattr(&secattr, ssp);
> @@ -3991,7 +3984,7 @@ static struct smack_known *smack_from_netlbl(const struct sock *sk, u16 family,
>   */
>  static int smack_socket_sock_rcv_skb(struct sock *sk, struct sk_buff *skb)
>  {
> -     struct socket_smack *ssp = sk->sk_security;
> +     struct socket_smack *ssp = smack_sock(sk);
>       struct smack_known *skp = NULL;
>       int rc = 0;
>       struct smk_audit_info ad;
> @@ -4090,12 +4083,11 @@ static int smack_socket_getpeersec_stream(struct socket *sock,
>                                         sockptr_t optval, sockptr_t optlen,
>                                         unsigned int len)
>  {
> -     struct socket_smack *ssp;
> +     struct socket_smack *ssp = smack_sock(sock->sk);
>       char *rcp = "";
>       u32 slen = 1;
>       int rc = 0;
>  
> -     ssp = sock->sk->sk_security;
>       if (ssp->smk_packet != NULL) {
>               rcp = ssp->smk_packet->smk_known;
>               slen = strlen(rcp) + 1;
> @@ -4145,7 +4137,7 @@ static int smack_socket_getpeersec_dgram(struct socket *sock,
>  
>       switch (family) {
>       case PF_UNIX:
> -             ssp = sock->sk->sk_security;
> +             ssp = smack_sock(sock->sk);
>               s = ssp->smk_out->smk_secid;
>               break;
>       case PF_INET:
> @@ -4194,7 +4186,7 @@ static void smack_sock_graft(struct sock *sk, struct socket *parent)
>           (sk->sk_family != PF_INET && sk->sk_family != PF_INET6))
>               return;
>  
> -     ssp = sk->sk_security;
> +     ssp = smack_sock(sk);
>       ssp->smk_in = skp;
>       ssp->smk_out = skp;
>       /* cssp->smk_packet is already set in smack_inet_csk_clone() */
> @@ -4214,7 +4206,7 @@ static int smack_inet_conn_request(const struct sock *sk, struct sk_buff *skb,
>  {
>       u16 family = sk->sk_family;
>       struct smack_known *skp;
> -     struct socket_smack *ssp = sk->sk_security;
> +     struct socket_smack *ssp = smack_sock(sk);
>       struct sockaddr_in addr;
>       struct iphdr *hdr;
>       struct smack_known *hskp;
> @@ -4300,7 +4292,7 @@ static int smack_inet_conn_request(const struct sock *sk, struct sk_buff *skb,
>  static void smack_inet_csk_clone(struct sock *sk,
>                                const struct request_sock *req)
>  {
> -     struct socket_smack *ssp = sk->sk_security;
> +     struct socket_smack *ssp = smack_sock(sk);
>       struct smack_known *skp;
>  
>       if (req->peer_secid != 0) {
> @@ -4868,6 +4860,7 @@ struct lsm_blob_sizes smack_blob_sizes __ro_after_init = {
>       .lbs_inode = sizeof(struct inode_smack),
>       .lbs_ipc = sizeof(struct smack_known *),
>       .lbs_msg_msg = sizeof(struct smack_known *),
> +     .lbs_sock = sizeof(struct socket_smack),
>       .lbs_superblock = sizeof(struct superblock_smack),
>  };
>  
> diff --git a/security/smack/smack_netfilter.c b/security/smack/smack_netfilter.c
> index b945c1d3a743..bad71b7e648d 100644
> --- a/security/smack/smack_netfilter.c
> +++ b/security/smack/smack_netfilter.c
> @@ -26,8 +26,8 @@ static unsigned int smack_ip_output(void *priv,
>       struct socket_smack *ssp;
>       struct smack_known *skp;
>  
> -     if (sk && sk->sk_security) {
> -             ssp = sk->sk_security;
> +     if (sk) {
> +             ssp = smack_sock(sk);
>               skp = ssp->smk_out;
>               skb->secmark = skp->smk_secid;
>       }

Reply via email to