From: Johannes Berg <[email protected]>

Add the ability to attach an arbitrary validation function to
NLA_BINARY policy entries.

NOTE: obviously, the nl80211 part should be split off, that's just
      to show how it might be used - saves ~1.2KiB there (x86-64)

NOTE: should probably allow this for the exact length fields, and
      perhaps even provide a general NLA_POLICY_VALID_ETH_ADDR define
      and similar?
---
 include/net/netlink.h  |  6 +++++
 lib/nlattr.c           |  6 +++++
 net/wireless/nl80211.c | 67 ++++++++++++++------------------------------------
 3 files changed, 31 insertions(+), 48 deletions(-)

diff --git a/include/net/netlink.h b/include/net/netlink.h
index ddabc832febc..ff80ccbfc6b0 100644
--- a/include/net/netlink.h
+++ b/include/net/netlink.h
@@ -256,6 +256,10 @@ enum {
  *                         of s16 - do that as usual in the code instead.
  *    All other            Unused - but note that it's a union
  *
+ * Meaning of `validate' field:
+ *    NLA_BINARY           Validation function, called only if .len is also set
+ *    All other            Unused - but note that it's a union
+ *
  * Example:
  * static const struct nla_policy my_policy[ATTR_MAX+1] = {
  *     [ATTR_FOO] = { .type = NLA_U16 },
@@ -272,6 +276,8 @@ struct nla_policy {
                struct {
                        s16 min, max;
                };
+               int (*validate)(const struct nlattr *attr,
+                               struct netlink_ext_ack *extack);
        };
 };
 
diff --git a/lib/nlattr.c b/lib/nlattr.c
index dd8d34c1ae19..4f003da28918 100644
--- a/lib/nlattr.c
+++ b/lib/nlattr.c
@@ -175,6 +175,12 @@ static int validate_nla(const struct nlattr *nla, int maxtype,
        case NLA_BINARY:
                if (pt->len && attrlen > pt->len)
                        goto out_err;
+
+               if (pt->len && pt->validate) {
+                       err = pt->validate(nla, extack);
+                       if (err)
+                               return err;
+               }
                break;
 
        case NLA_NESTED:
diff --git a/net/wireless/nl80211.c b/net/wireless/nl80211.c
index 551310d5965e..a09066896f64 100644
--- a/net/wireless/nl80211.c
+++ b/net/wireless/nl80211.c
@@ -200,6 +200,9 @@ cfg80211_get_dev_from_info(struct net *netns, struct genl_info *info)
        return __cfg80211_rdev_from_attrs(netns, info->attrs);
 }
 
+static int validate_ie_attr(const struct nlattr *attr,
+                           struct netlink_ext_ack *extack);
+
 /* policy for the attributes */
 static const struct nla_policy
 nl80211_pmsr_ftm_req_attr_policy[NL80211_PMSR_FTM_REQ_ATTR_MAX + 1] = {
@@ -311,7 +314,8 @@ const struct nla_policy nl80211_policy[NUM_NL80211_ATTR] = {
        [NL80211_ATTR_BEACON_HEAD] = { .type = NLA_BINARY,
                                       .len = IEEE80211_MAX_DATA_LEN },
        [NL80211_ATTR_BEACON_TAIL] = { .type = NLA_BINARY,
-                                      .len = IEEE80211_MAX_DATA_LEN },
+                                      .len = IEEE80211_MAX_DATA_LEN,
+                                      .validate = validate_ie_attr },
        [NL80211_ATTR_STA_AID] = {
                .type = NLA_U16,
                .min = 1,
@@ -348,7 +352,8 @@ const struct nla_policy nl80211_policy[NUM_NL80211_ATTR] = {
 
        [NL80211_ATTR_MGMT_SUBTYPE] = { .type = NLA_U8 },
        [NL80211_ATTR_IE] = { .type = NLA_BINARY,
-                             .len = IEEE80211_MAX_DATA_LEN },
+                             .len = IEEE80211_MAX_DATA_LEN,
+                             .validate = validate_ie_attr },
        [NL80211_ATTR_SCAN_FREQUENCIES] = { .type = NLA_NESTED },
        [NL80211_ATTR_SCAN_SSIDS] = { .type = NLA_NESTED },
 
@@ -417,9 +422,11 @@ const struct nla_policy nl80211_policy[NUM_NL80211_ATTR] = {
                .max = NL80211_HIDDEN_SSID_ZERO_CONTENTS,
        },
        [NL80211_ATTR_IE_PROBE_RESP] = { .type = NLA_BINARY,
-                                        .len = IEEE80211_MAX_DATA_LEN },
+                                        .len = IEEE80211_MAX_DATA_LEN,
+                                        .validate = validate_ie_attr },
        [NL80211_ATTR_IE_ASSOC_RESP] = { .type = NLA_BINARY,
-                                        .len = IEEE80211_MAX_DATA_LEN },
+                                        .len = IEEE80211_MAX_DATA_LEN,
+                                        .validate = validate_ie_attr },
        [NL80211_ATTR_ROAM_SUPPORT] = { .type = NLA_FLAG },
        [NL80211_ATTR_SCHED_SCAN_MATCH] = { .type = NLA_NESTED },
        [NL80211_ATTR_TX_NO_CCK_RATE] = { .type = NLA_FLAG },
@@ -738,14 +745,12 @@ static int nl80211_prepare_wdev_dump(struct sk_buff *skb,
 }
 
 /* IE validation */
-static bool is_valid_ie_attr(const struct nlattr *attr)
+static int validate_ie_attr(const struct nlattr *attr,
+                           struct netlink_ext_ack *extack)
 {
        const u8 *pos;
        int len;
 
-       if (!attr)
-               return true;
-
        pos = nla_data(attr);
        len = nla_len(attr);
 
@@ -753,18 +758,18 @@ static bool is_valid_ie_attr(const struct nlattr *attr)
                u8 elemlen;
 
                if (len < 2)
-                       return false;
+                       return -EINVAL;
                len -= 2;
 
                elemlen = pos[1];
                if (elemlen > len)
-                       return false;
+                       return -EINVAL;
 
                len -= elemlen;
                pos += 2 + elemlen;
        }
 
-       return true;
+       return 0;
 }
 
 /* message building helper */
@@ -4191,12 +4196,6 @@ static int nl80211_parse_beacon(struct nlattr *attrs[],
 {
        bool haveinfo = false;
 
-       if (!is_valid_ie_attr(attrs[NL80211_ATTR_BEACON_TAIL]) ||
-           !is_valid_ie_attr(attrs[NL80211_ATTR_IE]) ||
-           !is_valid_ie_attr(attrs[NL80211_ATTR_IE_PROBE_RESP]) ||
-           !is_valid_ie_attr(attrs[NL80211_ATTR_IE_ASSOC_RESP]))
-               return -EINVAL;
-
        memset(bcn, 0, sizeof(*bcn));
 
        if (attrs[NL80211_ATTR_BEACON_HEAD]) {
@@ -6326,7 +6325,8 @@ static const struct nla_policy
        [NL80211_MESH_SETUP_AUTH_PROTOCOL] = { .type = NLA_U8 },
        [NL80211_MESH_SETUP_USERSPACE_MPM] = { .type = NLA_FLAG },
        [NL80211_MESH_SETUP_IE] = { .type = NLA_BINARY,
-                                   .len = IEEE80211_MAX_DATA_LEN },
+                                   .len = IEEE80211_MAX_DATA_LEN,
+                                   .validate = validate_ie_attr },
        [NL80211_MESH_SETUP_USERSPACE_AMPE] = { .type = NLA_FLAG },
 };
 
@@ -6527,8 +6527,6 @@ static int nl80211_parse_mesh_setup(struct genl_info *info,
        if (tb[NL80211_MESH_SETUP_IE]) {
                struct nlattr *ieattr =
                        tb[NL80211_MESH_SETUP_IE];
-               if (!is_valid_ie_attr(ieattr))
-                       return -EINVAL;
                setup->ie = nla_data(ieattr);
                setup->ie_len = nla_len(ieattr);
        }
@@ -7161,9 +7159,6 @@ static int nl80211_trigger_scan(struct sk_buff *skb, struct genl_info *info)
        int err, tmp, n_ssids = 0, n_channels, i;
        size_t ie_len;
 
-       if (!is_valid_ie_attr(info->attrs[NL80211_ATTR_IE]))
-               return -EINVAL;
-
        wiphy = &rdev->wiphy;
 
        if (wdev->iftype == NL80211_IFTYPE_NAN)
@@ -7517,9 +7512,6 @@ nl80211_parse_sched_scan(struct wiphy *wiphy, struct wireless_dev *wdev,
        struct nlattr *tb[NL80211_SCHED_SCAN_MATCH_ATTR_MAX + 1];
        s32 default_match_rssi = NL80211_SCAN_RSSI_THOLD_OFF;
 
-       if (!is_valid_ie_attr(attrs[NL80211_ATTR_IE]))
-               return ERR_PTR(-EINVAL);
-
        if (attrs[NL80211_ATTR_SCAN_FREQUENCIES]) {
                n_channels = validate_scan_freqs(
                                attrs[NL80211_ATTR_SCAN_FREQUENCIES]);
@@ -8487,9 +8479,6 @@ static int nl80211_authenticate(struct sk_buff *skb, struct genl_info *info)
        struct key_parse key;
        bool local_state_change;
 
-       if (!is_valid_ie_attr(info->attrs[NL80211_ATTR_IE]))
-               return -EINVAL;
-
        if (!info->attrs[NL80211_ATTR_MAC])
                return -EINVAL;
 
@@ -8728,9 +8717,6 @@ static int nl80211_associate(struct sk_buff *skb, struct genl_info *info)
            dev->ieee80211_ptr->conn_owner_nlportid != info->snd_portid)
                return -EPERM;
 
-       if (!is_valid_ie_attr(info->attrs[NL80211_ATTR_IE]))
-               return -EINVAL;
-
        if (!info->attrs[NL80211_ATTR_MAC] ||
            !info->attrs[NL80211_ATTR_SSID] ||
            !info->attrs[NL80211_ATTR_WIPHY_FREQ])
@@ -8854,9 +8840,6 @@ static int nl80211_deauthenticate(struct sk_buff *skb, struct genl_info *info)
            dev->ieee80211_ptr->conn_owner_nlportid != info->snd_portid)
                return -EPERM;
 
-       if (!is_valid_ie_attr(info->attrs[NL80211_ATTR_IE]))
-               return -EINVAL;
-
        if (!info->attrs[NL80211_ATTR_MAC])
                return -EINVAL;
 
@@ -8905,9 +8888,6 @@ static int nl80211_disassociate(struct sk_buff *skb, struct genl_info *info)
            dev->ieee80211_ptr->conn_owner_nlportid != info->snd_portid)
                return -EPERM;
 
-       if (!is_valid_ie_attr(info->attrs[NL80211_ATTR_IE]))
-               return -EINVAL;
-
        if (!info->attrs[NL80211_ATTR_MAC])
                return -EINVAL;
 
@@ -8982,9 +8962,6 @@ static int nl80211_join_ibss(struct sk_buff *skb, struct genl_info *info)
 
        memset(&ibss, 0, sizeof(ibss));
 
-       if (!is_valid_ie_attr(info->attrs[NL80211_ATTR_IE]))
-               return -EINVAL;
-
        if (!info->attrs[NL80211_ATTR_SSID] ||
            !nla_len(info->attrs[NL80211_ATTR_SSID]))
                return -EINVAL;
@@ -9422,9 +9399,6 @@ static int nl80211_connect(struct sk_buff *skb, struct genl_info *info)
 
        memset(&connect, 0, sizeof(connect));
 
-       if (!is_valid_ie_attr(info->attrs[NL80211_ATTR_IE]))
-               return -EINVAL;
-
        if (!info->attrs[NL80211_ATTR_SSID] ||
            !nla_len(info->attrs[NL80211_ATTR_SSID]))
                return -EINVAL;
@@ -9655,8 +9629,6 @@ static int nl80211_update_connect_params(struct sk_buff *skb,
                return -EOPNOTSUPP;
 
        if (info->attrs[NL80211_ATTR_IE]) {
-               if (!is_valid_ie_attr(info->attrs[NL80211_ATTR_IE]))
-                       return -EINVAL;
                connect.ie = nla_data(info->attrs[NL80211_ATTR_IE]);
                connect.ie_len = nla_len(info->attrs[NL80211_ATTR_IE]);
                changed |= UPDATE_ASSOC_IES;
@@ -12308,8 +12280,7 @@ static int nl80211_update_ft_ies(struct sk_buff *skb, struct genl_info *info)
                return -EOPNOTSUPP;
 
        if (!info->attrs[NL80211_ATTR_MDID] ||
-           !info->attrs[NL80211_ATTR_IE] ||
-           !is_valid_ie_attr(info->attrs[NL80211_ATTR_IE]))
+           !info->attrs[NL80211_ATTR_IE])
                return -EINVAL;
 
        memset(&ft_params, 0, sizeof(ft_params));
-- 
2.14.4

Reply via email to