Module: Mesa
Branch: main
Commit: 6e2ba679ff7be6ecbe935049eca7d4eab32aa1e0
URL:    http://cgit.freedesktop.org/mesa/mesa/commit/?id=6e2ba679ff7be6ecbe935049eca7d4eab32aa1e0

Author: LingMan <[email protected]>
Date:   Thu Oct 12 21:31:31 2023 +0200

rusticl: add a safe abstraction to execute an SVMFreeCb

Reviewed-by: Karol Herbst <[email protected]>
Part-of: <https://gitlab.freedesktop.org/mesa/mesa/-/merge_requests/25669>

---

 src/gallium/frontends/rusticl/api/memory.rs | 13 ++++---------
 src/gallium/frontends/rusticl/api/types.rs  | 17 +++++++++++++++++
 2 files changed, 21 insertions(+), 9 deletions(-)

diff --git a/src/gallium/frontends/rusticl/api/memory.rs b/src/gallium/frontends/rusticl/api/memory.rs
index dfac9ff39aa..33941eb6298 100644
--- a/src/gallium/frontends/rusticl/api/memory.rs
+++ b/src/gallium/frontends/rusticl/api/memory.rs
@@ -2369,7 +2369,7 @@ fn enqueue_svm_free_impl(
     // The application is allowed to reuse or free the memory referenced by `svm_pointers` after this
     // function returns so we have to make a copy.
     // SAFETY: num_svm_pointers specifies the amount of elements in svm_pointers
-    let svm_pointers =
+    let mut svm_pointers =
         unsafe { slice::from_raw_parts(svm_pointers, num_svm_pointers as usize) }.to_vec();
     // SAFETY: The requirements on `SVMFreeCb::new` match the requirements
     // imposed by the OpenCL specification. It is the caller's duty to uphold them.
@@ -2382,15 +2382,10 @@ fn enqueue_svm_free_impl(
         event,
         false,
         Box::new(move |q, _| {
-            if let Some(cb) = &cb_opt {
-                let mut svm_pointers = svm_pointers.clone();
-                let ptr = svm_pointers.as_mut_ptr();
-                // SAFETY: it's undefined behavior if the application screws up
-                unsafe {
-                    (cb.func)(command_queue, num_svm_pointers, ptr, cb.data);
-                }
+            if let Some(cb) = cb_opt {
+                cb.call(q, &mut svm_pointers);
             } else {
-                for &ptr in &svm_pointers {
+                for ptr in svm_pointers {
                     svm_free_impl(&q.context, ptr);
                 }
             }
diff --git a/src/gallium/frontends/rusticl/api/types.rs b/src/gallium/frontends/rusticl/api/types.rs
index cad831cc264..0e8c7082cb3 100644
--- a/src/gallium/frontends/rusticl/api/types.rs
+++ b/src/gallium/frontends/rusticl/api/types.rs
@@ -3,6 +3,7 @@ use crate::api::icd::ReferenceCountedAPIPointer;
 use crate::core::context::Context;
 use crate::core::event::Event;
 use crate::core::memory::Mem;
+use crate::core::queue::Queue;
 
 use rusticl_opencl_gen::*;
 
@@ -172,6 +173,22 @@ cl_callback!(
     }
 );
 
+impl SVMFreeCb {
+    pub fn call(self, queue: &Queue, svm_pointers: &mut [*mut c_void]) {
+        let cl = cl_command_queue::from_ptr(queue);
+        // SAFETY: `cl` must be a valid pointer to an OpenCL queue, which is where we just got it from.
+        // All other requirements are covered by this callback's type invariants.
+        unsafe {
+            (self.func)(
+                cl,
+                svm_pointers.len() as u32,
+                svm_pointers.as_mut_ptr(),
+                self.data,
+            )
+        };
+    }
+}
+
 // a lot of APIs use 3 component vectors passed as C arrays
 #[derive(Clone, Copy, PartialEq, Eq, PartialOrd, Ord)]
 pub struct CLVec<T> {

Reply via email to