Module: Mesa
Branch: main
Commit: 241d16c9e80880a6dc66694bdf285f9cf065dfd4
URL:    http://cgit.freedesktop.org/mesa/mesa/commit/?id=241d16c9e80880a6dc66694bdf285f9cf065dfd4

Author: LingMan <[email protected]>
Date:   Thu Oct 12 19:49:59 2023 +0200

rusticl: add a safe abstraction to execute an EventCB

Reviewed-by: Karol Herbst <[email protected]>
Part-of: <https://gitlab.freedesktop.org/mesa/mesa/-/merge_requests/25669>

---

 src/gallium/frontends/rusticl/api/types.rs  | 10 ++++++++++
 src/gallium/frontends/rusticl/core/event.rs |  7 +++----
 2 files changed, 13 insertions(+), 4 deletions(-)

diff --git a/src/gallium/frontends/rusticl/api/types.rs b/src/gallium/frontends/rusticl/api/types.rs
index d117402c80c..0b5177acc70 100644
--- a/src/gallium/frontends/rusticl/api/types.rs
+++ b/src/gallium/frontends/rusticl/api/types.rs
@@ -1,6 +1,7 @@
 use crate::api::icd::CLResult;
 use crate::api::icd::ReferenceCountedAPIPointer;
 use crate::core::context::Context;
+use crate::core::event::Event;
 
 use rusticl_opencl_gen::*;
 
@@ -129,6 +130,15 @@ cl_callback!(
     }
 );
 
+impl EventCB {
+    pub fn call(self, event: &Event, status: cl_int) {
+        let cl = cl_event::from_ptr(event);
+        // SAFETY: `cl` must be a valid pointer to an OpenCL event, which is where we just got it from.
+        // All other requirements are covered by this callback's type invariants.
+        unsafe { (self.func)(cl, status, self.data) };
+    }
+}
+
 cl_callback!(
     MemCB(FuncMemCB) {
         memobj: cl_mem,
diff --git a/src/gallium/frontends/rusticl/core/event.rs b/src/gallium/frontends/rusticl/core/event.rs
index ecea0568f58..daff5c4e950 100644
--- a/src/gallium/frontends/rusticl/core/event.rs
+++ b/src/gallium/frontends/rusticl/core/event.rs
@@ -119,9 +119,8 @@ impl Event {
         }
 
         if [CL_COMPLETE, CL_RUNNING, CL_SUBMITTED].contains(&(new as u32)) {
-            if let Some(cbs) = lock.cbs.get(new as usize) {
-                cbs.iter()
-                    .for_each(|cb| unsafe { 
(cb.func)(cl_event::from_ptr(self), new, cb.data) });
+            if let Some(cbs) = lock.cbs.get_mut(new as usize) {
+                cbs.drain(..).for_each(|cb| cb.call(self, new));
             }
         }
     }
@@ -167,7 +166,7 @@ impl Event {
         // call cb if the status was already reached
         if state >= status {
             drop(lock);
-            unsafe { (cb.func)(cl_event::from_ptr(self), status, cb.data) };
+            cb.call(self, state);
         } else {
             lock.cbs.get_mut(state as usize).unwrap().push(cb);
         }

Reply via email to