Module: Mesa Branch: main Commit: 2b01934bc876d978080fd6812232dbb4c68f2a53 URL: http://cgit.freedesktop.org/mesa/mesa/commit/?id=2b01934bc876d978080fd6812232dbb4c68f2a53
Author: Antonio Gomes <[email protected]> Date: Mon Apr 10 20:57:07 2023 -0300 rusticl: Move nir compilation to Program Part-of: <https://gitlab.freedesktop.org/mesa/mesa/-/merge_requests/22434> --- src/gallium/frontends/rusticl/api/kernel.rs | 12 +-- src/gallium/frontends/rusticl/api/program.rs | 9 +- src/gallium/frontends/rusticl/core/kernel.rs | 130 ++++++++++---------------- src/gallium/frontends/rusticl/core/program.rs | 76 +++++++++++++++ 4 files changed, 136 insertions(+), 91 deletions(-) diff --git a/src/gallium/frontends/rusticl/api/kernel.rs b/src/gallium/frontends/rusticl/api/kernel.rs index 523a6f24e74..8a7e18067e9 100644 --- a/src/gallium/frontends/rusticl/api/kernel.rs +++ b/src/gallium/frontends/rusticl/api/kernel.rs @@ -170,11 +170,7 @@ pub fn create_kernel( return Err(CL_INVALID_KERNEL_DEFINITION); } - Ok(cl_kernel::from_arc(Kernel::new( - name, - p, - kernel_args.into_iter().next().unwrap(), - ))) + Ok(cl_kernel::from_arc(Kernel::new(name, p))) } pub fn create_kernels_in_program( @@ -207,11 +203,7 @@ pub fn create_kernels_in_program( unsafe { kernels .add(num_kernels as usize) - .write(cl_kernel::from_arc(Kernel::new( - name, - p.clone(), - kernel_args.into_iter().next().unwrap(), - ))); + .write(cl_kernel::from_arc(Kernel::new(name, p.clone()))); } } num_kernels += 1; diff --git a/src/gallium/frontends/rusticl/api/program.rs b/src/gallium/frontends/rusticl/api/program.rs index d987a9eb868..a4a4b4e06e9 100644 --- a/src/gallium/frontends/rusticl/api/program.rs +++ b/src/gallium/frontends/rusticl/api/program.rs @@ -235,7 +235,10 @@ pub fn create_program_with_binary( return Err(err); } - Ok(cl_program::from_arc(Program::from_bins(c, devs, &bins))) + let prog = Program::from_bins(c, devs, &bins); + prog.build_nirs(); + + Ok(cl_program::from_arc(prog)) //• CL_INVALID_BINARY if an invalid program binary was encountered for any device. binary_status will return specific status for each device. } @@ -289,6 +292,7 @@ pub fn build_program( //• CL_INVALID_OPERATION if program was not created with clCreateProgramWithSource, clCreateProgramWithIL or clCreateProgramWithBinary. if res { + p.build_nirs(); Ok(()) } else { if Platform::dbg().program { @@ -431,6 +435,9 @@ pub fn link_program( CL_LINK_PROGRAM_FAILURE }; + // Pre build nir kernels + res.build_nirs(); + let res = cl_program::from_arc(res); call_cb(pfn_notify, res, user_data); diff --git a/src/gallium/frontends/rusticl/core/kernel.rs b/src/gallium/frontends/rusticl/core/kernel.rs index ccc19030d87..b0fea9ef593 100644 --- a/src/gallium/frontends/rusticl/core/kernel.rs +++ b/src/gallium/frontends/rusticl/core/kernel.rs @@ -22,7 +22,6 @@ use rusticl_opencl_gen::*; use std::cell::RefCell; use std::cmp; use std::collections::HashMap; -use std::collections::HashSet; use std::convert::TryInto; use std::os::raw::c_void; use std::ptr; @@ -255,7 +254,7 @@ impl InternalKernelArg { } struct KernelDevStateInner { - nir: NirShader, + nir: Arc<NirShader>, constant_buffer: Option<Arc<PipeResource>>, cso: *mut c_void, info: pipe_compute_state_object_info, @@ -276,7 +275,7 @@ impl Drop for KernelDevState { } impl KernelDevState { - fn new(nirs: HashMap<Arc<Device>, NirShader>) -> Arc<Self> { + fn new(nirs: HashMap<Arc<Device>, Arc<NirShader>>) -> Arc<Self> { let states = nirs .into_iter() .map(|(dev, nir)| { @@ -736,94 +735,62 @@ fn deserialize_nir( Some((nir, args, internal_args)) } -fn convert_spirv_to_nir( +pub fn convert_spirv_to_nir( p: &Program, name: &str, - args: Vec<spirv::SPIRVKernelArg>, -) -> ( - HashMap<Arc<Device>, NirShader>, - Vec<KernelArg>, - Vec<InternalKernelArg>, - String, -) { - let mut nirs = HashMap::new(); - let mut args_set = HashSet::new(); - let mut internal_args_set = HashSet::new(); - let mut attributes_string_set = HashSet::new(); - - // TODO: we could run this in parallel? - for d in p.devs_with_build() { - let cache = d.screen().shader_cache(); - let key = p.hash_key(d, name); - - let res = if let Some(cache) = &cache { - cache.get(&mut key.unwrap()).and_then(|entry| { - let mut bin: &[u8] = &entry; - deserialize_nir(&mut bin, d) - }) - } else { - None - }; - - let (nir, args, internal_args) = if let Some(res) = res { - res - } else { - let mut nir = p.to_nir(name, d); - - /* this is a hack until we support fp16 properly and check for denorms inside - * vstore/vload_half - */ - nir.preserve_fp16_denorms(); + args: &[spirv::SPIRVKernelArg], + dev: &Arc<Device>, +) -> (NirShader, Vec<KernelArg>, Vec<InternalKernelArg>, String) { + let cache = dev.screen().shader_cache(); + let key = p.hash_key(dev, name); + + let res = if let Some(cache) = &cache { + cache.get(&mut key.unwrap()).and_then(|entry| { + let mut bin: &[u8] = &entry; + deserialize_nir(&mut bin, dev) + }) + } else { + None + }; - lower_and_optimize_nir_pre_inputs(d, &mut nir, &d.lib_clc); - let mut args = KernelArg::from_spirv_nir(&args, &mut nir); - let internal_args = lower_and_optimize_nir_late(d, &mut nir, &mut args); + let (nir, args, internal_args) = if let Some(res) = res { + res + } else { + let mut nir = p.to_nir(name, dev); - if let Some(cache) = cache { - let mut bin = Vec::new(); - let mut nir = nir.serialize(); + /* this is a hack until we support fp16 properly and check for denorms inside + * vstore/vload_half + */ + nir.preserve_fp16_denorms(); - bin.extend_from_slice(&nir.len().to_ne_bytes()); - bin.append(&mut nir); + lower_and_optimize_nir_pre_inputs(dev, &mut nir, &dev.lib_clc); + let mut args = KernelArg::from_spirv_nir(args, &mut nir); + let internal_args = lower_and_optimize_nir_late(dev, &mut nir, &mut args); - bin.extend_from_slice(&args.len().to_ne_bytes()); - for arg in &args { - bin.append(&mut arg.serialize()); - } + if let Some(cache) = cache { + let mut bin = Vec::new(); + let mut nir = nir.serialize(); - bin.extend_from_slice(&internal_args.len().to_ne_bytes()); - for arg in &internal_args { - bin.append(&mut arg.serialize()); - } + bin.extend_from_slice(&nir.len().to_ne_bytes()); + bin.append(&mut nir); - cache.put(&bin, &mut key.unwrap()); + bin.extend_from_slice(&args.len().to_ne_bytes()); + for arg in &args { + bin.append(&mut arg.serialize()); } - (nir, args, internal_args) - }; + bin.extend_from_slice(&internal_args.len().to_ne_bytes()); + for arg in &internal_args { + bin.append(&mut arg.serialize()); + } - args_set.insert(args); - internal_args_set.insert(internal_args); - nirs.insert(d.clone(), nir); - attributes_string_set.insert(p.attribute_str(name, d)); - } + cache.put(&bin, &mut key.unwrap()); + } - // we want the same (internal) args for every compiled kernel, for now - assert!(args_set.len() == 1); - assert!(internal_args_set.len() == 1); - assert!(attributes_string_set.len() == 1); - let args = args_set.into_iter().next().unwrap(); - let internal_args = internal_args_set.into_iter().next().unwrap(); - - // spec: For kernels not created from OpenCL C source and the clCreateProgramWithSource API call - // the string returned from this query [CL_KERNEL_ATTRIBUTES] will be empty. - let attributes_string = if p.is_src() { - attributes_string_set.into_iter().next().unwrap() - } else { - String::new() + (nir, args, internal_args) }; - (nirs, args, internal_args, attributes_string) + (nir, args, internal_args, p.attribute_str(name, dev)) } fn extract<'a, const S: usize>(buf: &'a mut &[u8]) -> &'a [u8; S] { @@ -835,9 +802,12 @@ fn extract<'a, const S: usize>(buf: &'a mut &[u8]) -> &'a [u8; S] { } impl Kernel { - pub fn new(name: String, prog: Arc<Program>, args: Vec<spirv::SPIRVKernelArg>) -> Arc<Kernel> { - let (mut nirs, args, internal_args, attributes_string) = - convert_spirv_to_nir(&prog, &name, args); + pub fn new(name: String, prog: Arc<Program>) -> Arc<Kernel> { + let nir_kernel_build = prog.get_nir_kernel_build(&name); + let mut nirs = nir_kernel_build.nirs; + let args = nir_kernel_build.args; + let internal_args = nir_kernel_build.internal_args; + let attributes_string = nir_kernel_build.attributes_string; let nir = nirs.values_mut().next().unwrap(); let wgs = nir.workgroup_size(); diff --git a/src/gallium/frontends/rusticl/core/program.rs b/src/gallium/frontends/rusticl/core/program.rs index afed3c8dea4..232b25e05ab 100644 --- a/src/gallium/frontends/rusticl/core/program.rs +++ b/src/gallium/frontends/rusticl/core/program.rs @@ -1,6 +1,7 @@ use crate::api::icd::*; use crate::core::context::*; use crate::core::device::*; +use crate::core::kernel::*; use crate::core::platform::Platform; use crate::impl_cl_type_trait; @@ -63,10 +64,19 @@ pub struct Program { pub kernel_count: AtomicU32, spec_constants: Mutex<HashMap<u32, nir_const_value>>, build: Mutex<ProgramBuild>, + nir_builds: Mutex<HashMap<String, NirKernelBuild>>, } impl_cl_type_trait!(cl_program, Program, CL_INVALID_PROGRAM); +#[derive(Clone)] +pub struct NirKernelBuild { + pub nirs: HashMap<Arc<Device>, Arc<NirShader>>, + pub args: Vec<KernelArg>, + pub internal_args: Vec<InternalKernelArg>, + pub attributes_string: String, +} + struct ProgramBuild { builds: HashMap<Arc<Device>, ProgramDevBuild>, kernels: Vec<String>, @@ -157,6 +167,7 @@ impl Program { builds: Self::create_default_builds(devs), kernels: Vec::new(), }), + nir_builds: Mutex::new(HashMap::new()), }) } @@ -229,6 +240,7 @@ impl Program { builds: builds, kernels: kernels.into_iter().collect(), }), + nir_builds: Mutex::new(HashMap::new()), }) } @@ -245,6 +257,7 @@ impl Program { builds: builds, kernels: Vec::new(), }), + nir_builds: Mutex::new(HashMap::new()), }) } @@ -259,6 +272,20 @@ impl Program { l.builds.get_mut(dev).unwrap() } + fn nir_build_info(&self) -> MutexGuard<HashMap<String, NirKernelBuild>> { + self.nir_builds.lock().unwrap() + } + + pub fn get_nir_kernel_build(&self, name: &str) -> NirKernelBuild { + let info = self.nir_build_info(); + info.get(name).unwrap().clone() + } + + pub fn set_nir_kernel_build(&self, name: &str, nir_build: NirKernelBuild) { + let mut info = self.nir_build_info(); + info.insert(String::from(name), nir_build); + } + pub fn status(&self, dev: &Arc<Device>) -> cl_build_status { Self::dev_build_info(&mut self.build_info(), dev).status } @@ -496,9 +523,58 @@ impl Program { builds: builds, kernels: kernels.into_iter().collect(), }), + nir_builds: Mutex::new(HashMap::new()), }) } + pub fn build_nir_kernel(&self, name: &str, args: Vec<spirv::SPIRVKernelArg>) -> NirKernelBuild { + let mut nirs = HashMap::new(); + let mut args_set = HashSet::new(); + let mut internal_args_set = HashSet::new(); + let mut attributes_string_set = HashSet::new(); + + // TODO: we could run this in parallel? + for d in self.devs_with_build() { + let (nir, args, internal_args, attributes_string) = + convert_spirv_to_nir(self, name, &args, d); + nirs.insert(d.clone(), Arc::new(nir)); + args_set.insert(args); + internal_args_set.insert(internal_args); + attributes_string_set.insert(attributes_string); + } + + // we want the same (internal) args for every compiled kernel, for now + assert!(args_set.len() == 1); + assert!(internal_args_set.len() == 1); + assert!(attributes_string_set.len() == 1); + let args = args_set.into_iter().next().unwrap(); + let internal_args = internal_args_set.into_iter().next().unwrap(); + + // spec: For kernels not created from OpenCL C source and the clCreateProgramWithSource API call + // the string returned from this query [CL_KERNEL_ATTRIBUTES] will be empty. + let attributes_string = if self.is_src() { + attributes_string_set.into_iter().next().unwrap() + } else { + String::new() + }; + + NirKernelBuild { + nirs: nirs, + args: args, + internal_args: internal_args, + attributes_string: attributes_string, + } + } + + pub fn build_nirs(&self) { + let devs = self.devs_with_build(); + for k in &self.kernels() { + let kernel_args: HashSet<_> = devs.iter().map(|d| self.args(d, k)).collect(); + let nir_build = self.build_nir_kernel(k, kernel_args.into_iter().next().unwrap()); + self.set_nir_kernel_build(k, nir_build); + } + } + pub(super) fn hash_key(&self, dev: &Arc<Device>, name: &str) -> Option<cache_key> { if let Some(cache) = dev.screen().shader_cache() { let mut lock = self.build_info();
