Module: Mesa
Branch: main
Commit: c5c4cc4137161786ed3837d29b30d3cdac77377e
URL:    http://cgit.freedesktop.org/mesa/mesa/commit/?id=c5c4cc4137161786ed3837d29b30d3cdac77377e

Author: LingMan <[email protected]>
Date:   Wed Oct 11 23:11:52 2023 +0200

rusticl: use EventCB

Reviewed-by: Karol Herbst <[email protected]>
Part-of: <https://gitlab.freedesktop.org/mesa/mesa/-/merge_requests/25669>

---

 src/gallium/frontends/rusticl/api/event.rs  | 18 +++++++-----------
 src/gallium/frontends/rusticl/api/types.rs  |  3 +++
 src/gallium/frontends/rusticl/core/event.rs | 11 +++++------
 3 files changed, 15 insertions(+), 17 deletions(-)

diff --git a/src/gallium/frontends/rusticl/api/event.rs b/src/gallium/frontends/rusticl/api/event.rs
index 0b1b2cdb914..1b50958bee0 100644
--- a/src/gallium/frontends/rusticl/api/event.rs
+++ b/src/gallium/frontends/rusticl/api/event.rs
@@ -120,20 +120,16 @@ fn set_event_callback(
 ) -> CLResult<()> {
     let e = event.get_ref()?;
 
-    // CL_INVALID_VALUE if pfn_event_notify is NULL
-    // or if command_exec_callback_type is not CL_SUBMITTED, CL_RUNNING, or 
CL_COMPLETE.
-    if pfn_event_notify.is_none()
-        || ![CL_SUBMITTED, CL_RUNNING, CL_COMPLETE]
-            .contains(&(command_exec_callback_type as cl_uint))
-    {
+    // CL_INVALID_VALUE [...] if command_exec_callback_type is not 
CL_SUBMITTED, CL_RUNNING, or CL_COMPLETE.
+    if ![CL_SUBMITTED, CL_RUNNING, 
CL_COMPLETE].contains(&(command_exec_callback_type as cl_uint)) {
         return Err(CL_INVALID_VALUE);
     }
 
-    e.add_cb(
-        command_exec_callback_type,
-        pfn_event_notify.unwrap(),
-        user_data,
-    );
+    // SAFETY: The requirements on `EventCB::new` match the requirements
+    // imposed by the OpenCL specification. It is the caller's duty to uphold 
them.
+    let cb = unsafe { EventCB::new(pfn_event_notify, user_data)? };
+
+    e.add_cb(command_exec_callback_type, cb);
 
     Ok(())
 }
diff --git a/src/gallium/frontends/rusticl/api/types.rs b/src/gallium/frontends/rusticl/api/types.rs
index 9f4f0c35ef8..544d674b532 100644
--- a/src/gallium/frontends/rusticl/api/types.rs
+++ b/src/gallium/frontends/rusticl/api/types.rs
@@ -47,9 +47,12 @@ macro_rules! cl_callback {
             ///   [`clCreateContext`] in the OpenCL specification.
             /// - DeleteContextCB: `func` must be soundly callable as 
documented on
             ///   [`clSetContextDestructorCallback`] in the OpenCL 
specification.
+            /// - EventCB: `func` must be soundly callable as documented on
+            ///   [`clSetEventCallback`] in the OpenCL specification.
             ///
             /// [`clCreateContext`]: 
https://registry.khronos.org/OpenCL/specs/3.0-unified/html/OpenCL_API.html#clCreateContext
             /// [`clSetContextDestructorCallback`]: 
https://registry.khronos.org/OpenCL/specs/3.0-unified/html/OpenCL_API.html#clSetContextDestructorCallback
+            /// [`clSetEventCallback`]: 
https://registry.khronos.org/OpenCL/specs/3.0-unified/html/OpenCL_API.html#clSetEventCallback
             pub unsafe fn new(func: Option<$fn_alias>, data: *mut c_void) -> 
CLResult<Self> {
                 let Some(func) = func else {
                     return Err(CL_INVALID_VALUE);
diff --git a/src/gallium/frontends/rusticl/core/event.rs b/src/gallium/frontends/rusticl/core/event.rs
index e51464710e6..44520928cc7 100644
--- a/src/gallium/frontends/rusticl/core/event.rs
+++ b/src/gallium/frontends/rusticl/core/event.rs
@@ -11,7 +11,6 @@ use mesa_rust_util::static_assert;
 use rusticl_opencl_gen::*;
 
 use std::collections::HashSet;
-use std::os::raw::c_void;
 use std::slice;
 use std::sync::Arc;
 use std::sync::Condvar;
@@ -37,7 +36,7 @@ pub enum EventTimes {
 #[derive(Default)]
 struct EventMutState {
     status: cl_int,
-    cbs: [Vec<(FuncEventCB, *mut c_void)>; 3],
+    cbs: [Vec<EventCB>; 3],
     work: Option<EventSig>,
     time_queued: cl_ulong,
     time_submit: cl_ulong,
@@ -122,7 +121,7 @@ impl Event {
         if [CL_COMPLETE, CL_RUNNING, CL_SUBMITTED].contains(&(new as u32)) {
             if let Some(cbs) = lock.cbs.get(new as usize) {
                 cbs.iter()
-                    .for_each(|(cb, data)| unsafe { 
cb(cl_event::from_ptr(self), new, *data) });
+                    .for_each(|cb| unsafe { 
(cb.func)(cl_event::from_ptr(self), new, cb.data) });
             }
         }
     }
@@ -161,16 +160,16 @@ impl Event {
         }
     }
 
-    pub fn add_cb(&self, state: cl_int, cb: FuncEventCB, data: *mut c_void) {
+    pub fn add_cb(&self, state: cl_int, cb: EventCB) {
         let mut lock = self.state();
         let status = lock.status;
 
         // call cb if the status was already reached
         if state >= status {
             drop(lock);
-            unsafe { cb(cl_event::from_ptr(self), status, data) };
+            unsafe { (cb.func)(cl_event::from_ptr(self), status, cb.data) };
         } else {
-            lock.cbs.get_mut(state as usize).unwrap().push((cb, data));
+            lock.cbs.get_mut(state as usize).unwrap().push(cb);
         }
     }
 

Reply via email to