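Move the udata validation from mlx4_ib_create_wq() into create_rq(), the function that actually copies the udata. Use ib_copy_validate_udata_in() there so that the length check, the check that trailing bytes are cleared, and the copy happen in one place.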

Signed-off-by: Jason Gunthorpe <[email protected]>
---
 drivers/infiniband/hw/mlx4/qp.c | 25 +++----------------------
 1 file changed, 3 insertions(+), 22 deletions(-)

diff --git a/drivers/infiniband/hw/mlx4/qp.c b/drivers/infiniband/hw/mlx4/qp.c
index deb1b0306aa7a1..40ddd723d7b549 100644
--- a/drivers/infiniband/hw/mlx4/qp.c
+++ b/drivers/infiniband/hw/mlx4/qp.c
@@ -854,7 +854,6 @@ static int create_rq(struct ib_pd *pd, struct ib_qp_init_attr *init_attr,
        unsigned long flags;
        int range_size;
        struct mlx4_ib_create_wq wq;
-       size_t copy_len;
        int shift;
        int n;
 
@@ -867,12 +866,9 @@ static int create_rq(struct ib_pd *pd, struct ib_qp_init_attr *init_attr,
 
        qp->state = IB_QPS_RESET;
 
-       copy_len = min(sizeof(struct mlx4_ib_create_wq), udata->inlen);
-
-       if (ib_copy_from_udata(&wq, udata, copy_len)) {
-               err = -EFAULT;
+       err = ib_copy_validate_udata_in(udata, wq, comp_mask);
+       if (err)
                goto err;
-       }
 
        if (wq.comp_mask || wq.reserved[0] || wq.reserved[1] ||
            wq.reserved[2]) {
@@ -4112,26 +4108,11 @@ struct ib_wq *mlx4_ib_create_wq(struct ib_pd *pd,
        struct mlx4_dev *dev = to_mdev(pd->device)->dev;
        struct ib_qp_init_attr ib_qp_init_attr = {};
        struct mlx4_ib_qp *qp;
-       struct mlx4_ib_create_wq ucmd;
-       int err, required_cmd_sz;
+       int err;
 
        if (!udata)
                return ERR_PTR(-EINVAL);
 
-       required_cmd_sz = offsetof(typeof(ucmd), comp_mask) +
-                         sizeof(ucmd.comp_mask);
-       if (udata->inlen < required_cmd_sz) {
-               pr_debug("invalid inlen\n");
-               return ERR_PTR(-EINVAL);
-       }
-
-       if (udata->inlen > sizeof(ucmd) &&
-           !ib_is_udata_cleared(udata, sizeof(ucmd),
-                                udata->inlen - sizeof(ucmd))) {
-               pr_debug("inlen is not supported\n");
-               return ERR_PTR(-EOPNOTSUPP);
-       }
-
        if (udata->outlen)
                return ERR_PTR(-EOPNOTSUPP);
 
-- 
2.43.0


Reply via email to