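A udata should be read only once per ioctl. Reading it more than once makes its content ambiguous, because userspace can change it between the reads. hns_roce_create_srq() currently copies the same udata twice: once in alloc_srq_buf() and again in alloc_srq_db().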
Lift the ib_copy_validate_udata_in() out of alloc_srq_buf()/alloc_srq_db() and into hns_roce_create_srq(). Found by AI. Signed-off-by: Jason Gunthorpe <[email protected]> --- drivers/infiniband/hw/hns/hns_roce_srq.c | 35 +++++++++++------------- 1 file changed, 16 insertions(+), 19 deletions(-) diff --git a/drivers/infiniband/hw/hns/hns_roce_srq.c b/drivers/infiniband/hw/hns/hns_roce_srq.c index 601f8cdfce96a3..cb848e8e6bbd76 100644 --- a/drivers/infiniband/hw/hns/hns_roce_srq.c +++ b/drivers/infiniband/hw/hns/hns_roce_srq.c @@ -340,22 +340,16 @@ static int set_srq_param(struct hns_roce_srq *srq, } static int alloc_srq_buf(struct hns_roce_dev *hr_dev, struct hns_roce_srq *srq, - struct ib_udata *udata) + struct ib_udata *udata, + struct hns_roce_ib_create_srq *ucmd) { - struct hns_roce_ib_create_srq ucmd = {}; int ret; - if (udata) { - ret = ib_copy_validate_udata_in(udata, ucmd, que_addr); - if (ret) - return ret; - } - - ret = alloc_srq_idx(hr_dev, srq, udata, ucmd.que_addr); + ret = alloc_srq_idx(hr_dev, srq, udata, ucmd->que_addr); if (ret) return ret; - ret = alloc_srq_wqe_buf(hr_dev, srq, udata, ucmd.buf_addr); + ret = alloc_srq_wqe_buf(hr_dev, srq, udata, ucmd->buf_addr); if (ret) goto err_idx; @@ -404,22 +398,18 @@ static void free_srq_db(struct hns_roce_dev *hr_dev, struct hns_roce_srq *srq, static int alloc_srq_db(struct hns_roce_dev *hr_dev, struct hns_roce_srq *srq, struct ib_udata *udata, + struct hns_roce_ib_create_srq *ucmd, struct hns_roce_ib_create_srq_resp *resp) { - struct hns_roce_ib_create_srq ucmd; struct hns_roce_ucontext *uctx; int ret; if (udata) { - ret = ib_copy_validate_udata_in(udata, ucmd, que_addr); - if (ret) - return ret; - if ((hr_dev->caps.flags & HNS_ROCE_CAP_FLAG_SRQ_RECORD_DB) && - (ucmd.req_cap_flags & HNS_ROCE_SRQ_CAP_RECORD_DB)) { + (ucmd->req_cap_flags & HNS_ROCE_SRQ_CAP_RECORD_DB)) { uctx = rdma_udata_to_drv_context(udata, struct hns_roce_ucontext, ibucontext); - ret = hns_roce_db_map_user(uctx, ucmd.db_addr, + ret = 
hns_roce_db_map_user(uctx, ucmd->db_addr, &srq->rdb); if (ret) return ret; @@ -448,6 +438,7 @@ int hns_roce_create_srq(struct ib_srq *ib_srq, struct hns_roce_dev *hr_dev = to_hr_dev(ib_srq->device); struct hns_roce_ib_create_srq_resp resp = {}; struct hns_roce_srq *srq = to_hr_srq(ib_srq); + struct hns_roce_ib_create_srq ucmd = {}; int ret; mutex_init(&srq->mutex); @@ -457,11 +448,17 @@ int hns_roce_create_srq(struct ib_srq *ib_srq, if (ret) goto err_out; - ret = alloc_srq_buf(hr_dev, srq, udata); + if (udata) { + ret = ib_copy_validate_udata_in(udata, ucmd, que_addr); + if (ret) + goto err_out; + } + + ret = alloc_srq_buf(hr_dev, srq, udata, &ucmd); if (ret) goto err_out; - ret = alloc_srq_db(hr_dev, srq, udata, &resp); + ret = alloc_srq_db(hr_dev, srq, udata, &ucmd, &resp); if (ret) goto err_srq_buf; -- 2.43.0
