Module: Mesa
Branch: main
Commit: b8171ab372cde0679cd3854aaa8b7512995915f3
URL:    http://cgit.freedesktop.org/mesa/mesa/commit/?id=b8171ab372cde0679cd3854aaa8b7512995915f3

Author: LingMan <[email protected]>
Date:   Fri Oct 13 23:20:43 2023 +0200

rusticl: use ProgramCB

Reviewed-by: Karol Herbst <[email protected]>
Part-of: <https://gitlab.freedesktop.org/mesa/mesa/-/merge_requests/25669>

---

 src/gallium/frontends/rusticl/api/program.rs | 35 +++++++++++++++-------------
 src/gallium/frontends/rusticl/api/types.rs   |  3 +++
 src/gallium/frontends/rusticl/api/util.rs    | 10 --------
 3 files changed, 22 insertions(+), 26 deletions(-)

diff --git a/src/gallium/frontends/rusticl/api/program.rs b/src/gallium/frontends/rusticl/api/program.rs
index 9c4a8a5683f..1cef4581612 100644
--- a/src/gallium/frontends/rusticl/api/program.rs
+++ b/src/gallium/frontends/rusticl/api/program.rs
@@ -92,16 +92,6 @@ fn validate_devices<'a>(
     Ok(devs)
 }
 
-fn call_cb(
-    pfn_notify: Option<FuncProgramCB>,
-    program: cl_program,
-    user_data: *mut ::std::os::raw::c_void,
-) {
-    if let Some(cb) = pfn_notify {
-        unsafe { cb(program, user_data) };
-    }
-}
-
 #[cl_entrypoint]
 fn create_program_with_source(
     context: cl_context,
@@ -283,7 +273,9 @@ fn build_program(
     let p = program.get_ref()?;
     let devs = validate_devices(device_list, num_devices, &p.devs)?;
 
-    check_cb(&pfn_notify, user_data)?;
+    // SAFETY: The requirements on `ProgramCB::try_new` match the requirements
+    // imposed by the OpenCL specification. It is the caller's duty to uphold 
them.
+    let cb_opt = unsafe { ProgramCB::try_new(pfn_notify, user_data)? };
 
     // CL_INVALID_OPERATION if there are kernel objects attached to program.
     if p.active_kernels() {
@@ -296,7 +288,9 @@ fn build_program(
         res &= p.build(dev, c_string_to_string(options));
     }
 
-    call_cb(pfn_notify, program, user_data);
+    if let Some(cb) = cb_opt {
+        unsafe { (cb.func)(program, cb.data) };
+    }
 
     //• CL_INVALID_BINARY if program is created with clCreateProgramWithBinary 
and devices listed in device_list do not have a valid program binary loaded.
     //• CL_INVALID_BUILD_OPTIONS if the build options specified by options are 
invalid.
@@ -331,7 +325,9 @@ fn compile_program(
     let p = program.get_ref()?;
     let devs = validate_devices(device_list, num_devices, &p.devs)?;
 
-    check_cb(&pfn_notify, user_data)?;
+    // SAFETY: The requirements on `ProgramCB::try_new` match the requirements
+    // imposed by the OpenCL specification. It is the caller's duty to uphold 
them.
+    let cb_opt = unsafe { ProgramCB::try_new(pfn_notify, user_data)? };
 
     // CL_INVALID_VALUE if num_input_headers is zero and header_include_names 
or input_headers are
     // not NULL or if num_input_headers is not zero and header_include_names 
or input_headers are
@@ -378,7 +374,9 @@ fn compile_program(
         res &= p.compile(dev, c_string_to_string(options), &headers);
     }
 
-    call_cb(pfn_notify, program, user_data);
+    if let Some(cb) = cb_opt {
+        unsafe { (cb.func)(program, cb.data) };
+    }
 
     // • CL_INVALID_COMPILER_OPTIONS if the compiler options specified by 
options are invalid.
     // • CL_INVALID_OPERATION if the compilation or build of a program 
executable for any of the devices listed in device_list by a previous call to 
clCompileProgram or clBuildProgram for program has not completed.
@@ -409,7 +407,9 @@ pub fn link_program(
     let devs = validate_devices(device_list, num_devices, &c.devs)?;
     let progs = cl_program::get_arc_vec_from_arr(input_programs, 
num_input_programs)?;
 
-    check_cb(&pfn_notify, user_data)?;
+    // SAFETY: The requirements on `ProgramCB::try_new` match the requirements
+    // imposed by the OpenCL specification. It is the caller's duty to uphold 
them.
+    let cb_opt = unsafe { ProgramCB::try_new(pfn_notify, user_data)? };
 
     // CL_INVALID_VALUE if num_input_programs is zero and input_programs is 
NULL
     if progs.is_empty() {
@@ -449,7 +449,10 @@ pub fn link_program(
 
     let res = cl_program::from_arc(res);
 
-    call_cb(pfn_notify, res, user_data);
+    if let Some(cb) = cb_opt {
+        unsafe { (cb.func)(res, cb.data) };
+    }
+
     Ok((res, code))
 
     //• CL_INVALID_LINKER_OPTIONS if the linker options specified by options 
are invalid.
diff --git a/src/gallium/frontends/rusticl/api/types.rs b/src/gallium/frontends/rusticl/api/types.rs
index 26072253353..2e9b518df54 100644
--- a/src/gallium/frontends/rusticl/api/types.rs
+++ b/src/gallium/frontends/rusticl/api/types.rs
@@ -51,11 +51,14 @@ macro_rules! cl_callback {
             ///   [`clSetEventCallback`] in the OpenCL specification.
             /// - MemCB: `func` must be soundly callable as documented on
             ///   [`clSetMemObjectDestructorCallback`] in the OpenCL 
specification.
+            /// - ProgramCB: `func` must be soundly callable as documented on
+            ///   [`clBuildProgram`] in the OpenCL specification.
             ///
             /// [`clCreateContext`]: 
https://registry.khronos.org/OpenCL/specs/3.0-unified/html/OpenCL_API.html#clCreateContext
             /// [`clSetContextDestructorCallback`]: 
https://registry.khronos.org/OpenCL/specs/3.0-unified/html/OpenCL_API.html#clSetContextDestructorCallback
             /// [`clSetEventCallback`]: 
https://registry.khronos.org/OpenCL/specs/3.0-unified/html/OpenCL_API.html#clSetEventCallback
             /// [`clSetMemObjectDestructorCallback`]: 
https://registry.khronos.org/OpenCL/specs/3.0-unified/html/OpenCL_API.html#clSetMemObjectDestructorCallback
+            /// [`clBuildProgram`]: 
https://registry.khronos.org/OpenCL/specs/3.0-unified/html/OpenCL_API.html#clBuildProgram
             pub unsafe fn new(func: Option<$fn_alias>, data: *mut c_void) -> 
CLResult<Self> {
                 let Some(func) = func else {
                     return Err(CL_INVALID_VALUE);
diff --git a/src/gallium/frontends/rusticl/api/util.rs b/src/gallium/frontends/rusticl/api/util.rs
index 39c8084de01..2ec573af805 100644
--- a/src/gallium/frontends/rusticl/api/util.rs
+++ b/src/gallium/frontends/rusticl/api/util.rs
@@ -13,7 +13,6 @@ use std::ffi::CStr;
 use std::ffi::CString;
 use std::mem::{size_of, MaybeUninit};
 use std::ops::BitAnd;
-use std::os::raw::c_void;
 use std::slice;
 use std::sync::Arc;
 
@@ -294,15 +293,6 @@ pub fn to_maybeuninit_vec<T: Copy>(v: Vec<T>) -> 
Vec<MaybeUninit<T>> {
     v.into_iter().map(MaybeUninit::new).collect()
 }
 
-pub fn check_cb<T>(cb: &Option<T>, user_data: *mut c_void) -> CLResult<()> {
-    // CL_INVALID_VALUE if pfn_notify is NULL but user_data is not NULL.
-    if cb.is_none() && !user_data.is_null() {
-        return Err(CL_INVALID_VALUE);
-    }
-
-    Ok(())
-}
-
 pub fn checked_compare(a: usize, o: cmp::Ordering, b: u64) -> bool {
     if usize::BITS > u64::BITS {
         a.cmp(&(b as usize)) == o

Reply via email to