Module: Mesa
Branch: main
Commit: 76996e2a944a91c498b79198fbbb1df4cb2cff59
URL:    http://cgit.freedesktop.org/mesa/mesa/commit/?id=76996e2a944a91c498b79198fbbb1df4cb2cff59

Author: LingMan <18294-ling...@users.noreply.gitlab.freedesktop.org>
Date:   Mon Nov 13 05:29:47 2023 +0100

rusticl: Use the `from_raw_parts` wrappers

Deduplicates the safety checks required before calling `slice::from_raw_parts` /
`from_raw_parts_mut` and ensures none of them is forgotten at a call site.

Reviewed-by: Karol Herbst <kher...@redhat.com>
Part-of: <https://gitlab.freedesktop.org/mesa/mesa/-/merge_requests/26157>

---

 src/gallium/frontends/rusticl/api/memory.rs | 66 ++++++-----------------------
 src/gallium/frontends/rusticl/api/util.rs   |  1 -
 2 files changed, 14 insertions(+), 53 deletions(-)

diff --git a/src/gallium/frontends/rusticl/api/memory.rs b/src/gallium/frontends/rusticl/api/memory.rs
index 7221d0e13cb..4caf5ceb849 100644
--- a/src/gallium/frontends/rusticl/api/memory.rs
+++ b/src/gallium/frontends/rusticl/api/memory.rs
@@ -2466,11 +2466,6 @@ fn enqueue_svm_memcpy_impl(
         return Err(CL_INVALID_OPERATION);
     }
 
-    // CL_INVALID_VALUE if dst_ptr or src_ptr is NULL.
-    if dst_ptr.is_null() || src_ptr.is_null() {
-        return Err(CL_INVALID_VALUE);
-    }
-
     // CL_MEM_COPY_OVERLAP if the values specified for dst_ptr, src_ptr and 
size result in an
     // overlapping copy.
     let dst_ptr_addr = dst_ptr as usize;
@@ -2481,14 +2476,6 @@ fn enqueue_svm_memcpy_impl(
         return Err(CL_MEM_COPY_OVERLAP);
     }
 
-    // Not technically guaranteed by the OpenCL spec, but required by 
`from_raw_parts` below.
-    if isize::try_from(size).is_err()
-        || src_ptr_addr.checked_add(size).is_none()
-        || dst_ptr_addr.checked_add(size).is_none()
-    {
-        return Err(CL_INVALID_VALUE);
-    }
-
     // CAST: We have no idea about the type or initialization status of these 
bytes.
     // MaybeUninit<u8> is the safe bet.
     let src_ptr = src_ptr.cast::<MaybeUninit<u8>>();
@@ -2497,16 +2484,14 @@ fn enqueue_svm_memcpy_impl(
     // MaybeUninit<u8> is the safe bet.
     let dst_ptr = dst_ptr.cast::<MaybeUninit<u8>>();
 
-    // SAFETY: We've checked above that the pointer is not NULL, that the size 
isn't excessive, and
-    // that addr + size doesn't overflow. It is up to the application to 
ensure the memory is valid
-    // to read for `size` bytes and that it doesn't modify it until the 
command has completed.
-    let src = unsafe { slice::from_raw_parts(src_ptr, size) };
+    // SAFETY: It is up to the application to ensure the memory is valid to 
read for `size` bytes
+    // and that it doesn't modify it until the command has completed.
+    let src = unsafe { cl_slice::from_raw_parts(src_ptr, size)? };
 
-    // SAFETY: We've checked above that the pointer is not NULL, that the size 
isn't excessive, and
-    // that addr + size doesn't overflow. We've also ensured there's no 
aliasing between src and
-    // dst. It is up to the application to ensure the memory is valid to read 
and write for `size`
-    // bytes and that it doesn't modify or read from it until the command has 
completed.
-    let dst = unsafe { slice::from_raw_parts_mut(dst_ptr, size) };
+    // SAFETY: We've ensured there's no aliasing between src and dst. It is up 
to the application
+    // to ensure the memory is valid to read and write for `size` bytes and 
that it doesn't modify
+    // or read from it until the command has completed.
+    let dst = unsafe { cl_slice::from_raw_parts_mut(dst_ptr, size)? };
 
     create_and_queue(
         q,
@@ -2582,23 +2567,12 @@ fn enqueue_svm_mem_fill_impl(
 ) -> CLResult<()> {
     let q = command_queue.get_arc()?;
     let evs = event_list_from_cl(&q, num_events_in_wait_list, 
event_wait_list)?;
-    let svm_ptr_addr = svm_ptr as usize;
 
     // CL_INVALID_OPERATION if the device associated with command queue does 
not support SVM.
     if !q.device.svm_supported() {
         return Err(CL_INVALID_OPERATION);
     }
 
-    // CL_INVALID_VALUE if svm_ptr is NULL.
-    if svm_ptr.is_null() {
-        return Err(CL_INVALID_VALUE);
-    }
-
-    // CL_INVALID_VALUE if svm_ptr is not aligned to pattern_size bytes.
-    if svm_ptr_addr & (pattern_size - 1) != 0 {
-        return Err(CL_INVALID_VALUE);
-    }
-
     // CL_INVALID_VALUE if pattern is NULL [...]
     if pattern.is_null() {
         return Err(CL_INVALID_VALUE);
@@ -2609,11 +2583,6 @@ fn enqueue_svm_mem_fill_impl(
         return Err(CL_INVALID_VALUE);
     }
 
-    // Not technically guaranteed by the OpenCL spec, but required by 
`from_raw_parts_mut` below.
-    if isize::try_from(size).is_err() || 
svm_ptr_addr.checked_add(size).is_none() {
-        return Err(CL_INVALID_VALUE);
-    }
-
     // The provided `$bytesize` must equal `pattern_size`.
     macro_rules! generate_fill_closure {
         ($bytesize:literal) => {{
@@ -2672,16 +2641,12 @@ fn enqueue_svm_mem_fill_impl(
             // the same layout as `Pattern`.
             let svm_ptr = svm_ptr.cast::<MaybeUninit<Pattern>>();
 
-            // SAFETY: We've checked that `svm_ptr` is not NULL above. It is 
otherwise the calling
-            // application's responsibility to ensure that it is valid for 
reads and writes up to
-            // `size` bytes.
+            // SAFETY: It is the calling application's responsibility to 
ensure that `svm_ptr` is
+            // valid for reads and writes up to `size` bytes.
             // Since `pattern_size == mem::size_of::<Pattern>()` and 
`MaybeUninit<Pattern>` has the
             // same layout as `Pattern`, we know that
             // `size / pattern_size * mem::size_of<MaybeUninit<Pattern>>` 
equals `size`.
             //
-            // We've also checked that `svm_ptr` has an alignment of 
`pattern_size` which fulfills
-            // `Pattern`'s requirement.
-            //
             // Since we're creating a `&[MaybeUninit<Pattern>]` the 
initialization status does not
             // matter.
             //
@@ -2689,10 +2654,7 @@ fn enqueue_svm_mem_fill_impl(
             // particular, since we've made a copy of `pattern`, it doesn't 
matter if the memory
             // region referenced by `pattern` aliases the one referenced by 
this slice. It is up to
             // the application not to access it at all until this command has 
been completed.
-            //
-            // We've checked that `size` does not exceed `isize::MAX` and that 
`svm_ptr + size`
-            // does not overflow above.
-            let svm_slice = unsafe { slice::from_raw_parts_mut(svm_ptr, size / 
pattern_size) };
+            let svm_slice = unsafe { cl_slice::from_raw_parts_mut(svm_ptr, 
size / pattern_size)? };
 
             Box::new(move |_, _| {
                 for x in svm_slice {
@@ -2935,20 +2897,20 @@ fn enqueue_svm_migrate_mem(
         return Err(CL_INVALID_OPERATION);
     }
 
-    // CL_INVALID_VALUE if num_svm_pointers is zero or svm_pointers is NULL.
-    if num_svm_pointers == 0 || svm_pointers.is_null() {
+    // CL_INVALID_VALUE if num_svm_pointers is zero
+    if num_svm_pointers == 0 {
         return Err(CL_INVALID_VALUE);
     }
 
     let num_svm_pointers = num_svm_pointers as usize;
     // SAFETY: Just hoping the application is alright.
     let mut svm_pointers =
-        unsafe { slice::from_raw_parts(svm_pointers, num_svm_pointers) 
}.to_owned();
+        unsafe { cl_slice::from_raw_parts(svm_pointers, num_svm_pointers)? 
}.to_owned();
     // if sizes is NULL, every allocation containing the pointers need to be 
migrated
     let mut sizes = if sizes.is_null() {
         vec![0; num_svm_pointers]
     } else {
-        unsafe { slice::from_raw_parts(sizes, num_svm_pointers) }.to_owned()
+        unsafe { cl_slice::from_raw_parts(sizes, num_svm_pointers)? 
}.to_owned()
     };
 
     // CL_INVALID_VALUE if sizes[i] is non-zero range [svm_pointers[i], 
svm_pointers[i]+sizes[i]) is
diff --git a/src/gallium/frontends/rusticl/api/util.rs b/src/gallium/frontends/rusticl/api/util.rs
index 325ad12a027..934f46d848c 100644
--- a/src/gallium/frontends/rusticl/api/util.rs
+++ b/src/gallium/frontends/rusticl/api/util.rs
@@ -366,7 +366,6 @@ pub fn check_copy_overlap(
     true
 }
 
-#[allow(dead_code)]
 pub mod cl_slice {
     use crate::api::util::CLResult;
     use mesa_rust_util::ptr::addr;

Reply via email to