Module: Mesa Branch: main Commit: 3a69424e09e0069d8f2ed04c7018cd17d66df743 URL: http://cgit.freedesktop.org/mesa/mesa/commit/?id=3a69424e09e0069d8f2ed04c7018cd17d66df743
Author: Konstantin Seurer <konstantin.seu...@gmail.com> Date: Thu Jul 20 21:13:22 2023 +0200 radv/nir: Add radv_nir_lower_hit_attrib_derefs Move out the pass so it can be unit tested. Reviewed-by: Friedrich Vock <friedrich.v...@gmx.de> Part-of: <https://gitlab.freedesktop.org/mesa/mesa/-/merge_requests/24271> --- src/amd/vulkan/meson.build | 1 + src/amd/vulkan/nir/radv_nir.h | 2 + .../vulkan/nir/radv_nir_lower_hit_attrib_derefs.c | 116 +++++++++++++++++++++ src/amd/vulkan/radv_rt_shader.c | 110 +------------------ 4 files changed, 123 insertions(+), 106 deletions(-) diff --git a/src/amd/vulkan/meson.build b/src/amd/vulkan/meson.build index b72fcfcb0fe..36f664d35a9 100644 --- a/src/amd/vulkan/meson.build +++ b/src/amd/vulkan/meson.build @@ -79,6 +79,7 @@ libradv_files = files( 'nir/radv_nir_lower_cooperative_matrix.c', 'nir/radv_nir_lower_fs_barycentric.c', 'nir/radv_nir_lower_fs_intrinsics.c', + 'nir/radv_nir_lower_hit_attrib_derefs.c', 'nir/radv_nir_lower_intrinsics_early.c', 'nir/radv_nir_lower_io.c', 'nir/radv_nir_lower_poly_line_smooth.c', diff --git a/src/amd/vulkan/nir/radv_nir.h b/src/amd/vulkan/nir/radv_nir.h index 44a51d747e1..a93678dc214 100644 --- a/src/amd/vulkan/nir/radv_nir.h +++ b/src/amd/vulkan/nir/radv_nir.h @@ -50,6 +50,8 @@ void radv_nir_lower_abi(nir_shader *shader, enum amd_gfx_level gfx_level, const const struct radv_shader_args *args, const struct radv_pipeline_key *pl_key, uint32_t address32_hi); +bool radv_nir_lower_hit_attrib_derefs(nir_shader *shader); + bool radv_nir_lower_ray_queries(struct nir_shader *shader, struct radv_device *device); bool radv_nir_lower_vs_inputs(nir_shader *shader, const struct radv_shader_stage *vs_stage, diff --git a/src/amd/vulkan/nir/radv_nir_lower_hit_attrib_derefs.c b/src/amd/vulkan/nir/radv_nir_lower_hit_attrib_derefs.c new file mode 100644 index 00000000000..601a40a81c6 --- /dev/null +++ b/src/amd/vulkan/nir/radv_nir_lower_hit_attrib_derefs.c @@ -0,0 +1,116 @@ +/* + * Copyright © 2021 Google + * Copyright 
© 2023 Valve Corporation + * SPDX-License-Identifier: MIT + */ + +#include "nir.h" +#include "nir_builder.h" +#include "radv_nir.h" + /* lower_hit_attrib_deref: nir_shader_instructions_pass callback. Rewrites a load_deref or store_deref whose deref is in nir_var_ray_hit_attrib mode into nir_load_hit_attrib_amd / nir_store_hit_attrib_amd accesses on 32-bit slots (.base = byte offset / 4), then removes the original instruction and returns true. Any other instruction is left untouched and false is returned. The data argument is unused (NULL is passed). */ +static bool +lower_hit_attrib_deref(nir_builder *b, nir_instr *instr, void *data) +{ + if (instr->type != nir_instr_type_intrinsic) + return false; + + nir_intrinsic_instr *intrin = nir_instr_as_intrinsic(instr); + if (intrin->intrinsic != nir_intrinsic_load_deref && intrin->intrinsic != nir_intrinsic_store_deref) + return false; + + nir_deref_instr *deref = nir_src_as_deref(intrin->src[0]); + if (!nir_deref_mode_is(deref, nir_var_ray_hit_attrib)) + return false; + /* Only whole-variable derefs are handled. radv_nir_lower_hit_attrib_derefs splits struct/array variables and lowers indirect derefs before running this callback, presumably to guarantee this. */ + assert(deref->deref_type == nir_deref_type_var); + /* Replacement instructions are emitted after the original access, which is removed once the replacement is built. */ + b->cursor = nir_after_instr(instr); + + if (intrin->intrinsic == nir_intrinsic_load_deref) { + uint32_t num_components = intrin->def.num_components; + uint32_t bit_size = intrin->def.bit_size; + + nir_def *components[NIR_MAX_VEC_COMPONENTS]; + /* Each component lives at byte offset driver_location + comp * bit_size / 8. base is the 32-bit slot index and comp_offset the byte position inside that slot. 64-bit components are rebuilt from slots base and base + 1; 16-bit and 8-bit components are extracted from their containing slot by unpacking it and selecting the half or byte at comp_offset. */ + for (uint32_t comp = 0; comp < num_components; comp++) { + uint32_t offset = deref->var->data.driver_location + comp * bit_size / 8; + uint32_t base = offset / 4; + uint32_t comp_offset = offset % 4; + + if (bit_size == 64) { + components[comp] = nir_pack_64_2x32_split(b, nir_load_hit_attrib_amd(b, .base = base), + nir_load_hit_attrib_amd(b, .base = base + 1)); + } else if (bit_size == 32) { + components[comp] = nir_load_hit_attrib_amd(b, .base = base); + } else if (bit_size == 16) { + components[comp] = + nir_channel(b, nir_unpack_32_2x16(b, nir_load_hit_attrib_amd(b, .base = base)), comp_offset / 2); + } else if (bit_size == 8) { + components[comp] = + nir_channel(b, nir_unpack_bits(b, nir_load_hit_attrib_amd(b, .base = base), 8), comp_offset); + } else { /* NOTE(review): any other bit size (e.g. 1-bit booleans) ends up here; confirm such hit attributes can never reach this pass. */ + unreachable("Invalid bit_size"); + } + } + + nir_def_rewrite_uses(&intrin->def, nir_vec(b, components, num_components)); + } else { + nir_def *value = intrin->src[1].ssa; + uint32_t num_components = value->num_components; + uint32_t bit_size = value->bit_size; + /* Store path: same per-component offset math as the load path. 64-bit and 32-bit components are written directly; 16-bit and 8-bit components need a read-modify-write of their containing 32-bit slot. */ + for (uint32_t comp = 0; comp < num_components; comp++) { +
uint32_t offset = deref->var->data.driver_location + comp * bit_size / 8; + uint32_t base = offset / 4; + uint32_t comp_offset = offset % 4; + + nir_def *component = nir_channel(b, value, comp); + + if (bit_size == 64) { /* Low 32 bits go to slot base, high 32 bits to slot base + 1. */ + nir_store_hit_attrib_amd(b, nir_unpack_64_2x32_split_x(b, component), .base = base); + nir_store_hit_attrib_amd(b, nir_unpack_64_2x32_split_y(b, component), .base = base + 1); + } else if (bit_size == 32) { + nir_store_hit_attrib_amd(b, component, .base = base); + } else if (bit_size == 16) { /* Load the slot, replace the 16-bit half selected by comp_offset / 2, repack and store it back. */ + nir_def *prev = nir_unpack_32_2x16(b, nir_load_hit_attrib_amd(b, .base = base)); + nir_def *components[2]; + for (uint32_t word = 0; word < 2; word++) + components[word] = (word == comp_offset / 2) ? nir_channel(b, value, comp) : nir_channel(b, prev, word); + nir_store_hit_attrib_amd(b, nir_pack_32_2x16(b, nir_vec(b, components, 2)), .base = base); + } else if (bit_size == 8) { /* Same read-modify-write, replacing only the byte at comp_offset. */ + nir_def *prev = nir_unpack_bits(b, nir_load_hit_attrib_amd(b, .base = base), 8); + nir_def *components[4]; + for (uint32_t byte = 0; byte < 4; byte++) + components[byte] = (byte == comp_offset) ?
nir_channel(b, value, comp) : nir_channel(b, prev, byte); + nir_store_hit_attrib_amd(b, nir_pack_32_4x8(b, nir_vec(b, components, 4)), .base = base); + } else { + unreachable("Invalid bit_size"); + } + } + } + /* The replacement has been emitted after the original access, so the original can now be dropped. */ + nir_instr_remove(instr); + return true; +} + /* radv_nir_lower_hit_attrib_derefs: lowers every nir_var_ray_hit_attrib load_deref/store_deref in the shader to nir_load_hit_attrib_amd / nir_store_hit_attrib_amd. Returns true if any sub-pass reported progress; only in that case are dead derefs and dead hit-attribute variables removed. */ +bool +radv_nir_lower_hit_attrib_derefs(nir_shader *shader) +{ + bool progress = false; + /* Reduce hit-attribute accesses to whole-variable derefs, which the assert in lower_hit_attrib_deref expects. Structs are split, indirect derefs are lowered, then arrays are split. UINT32_MAX appears to be a size limit chosen to cover every variable; confirm against nir_lower_indirect_derefs. */ + progress |= nir_split_struct_vars(shader, nir_var_ray_hit_attrib); + progress |= nir_lower_indirect_derefs(shader, nir_var_ray_hit_attrib, UINT32_MAX); + progress |= nir_split_array_vars(shader, nir_var_ray_hit_attrib); + /* Assign explicit layouts using natural size/alignment. The callback treats var->data.driver_location as a byte offset, presumably the one assigned here. */ + progress |= nir_lower_vars_to_explicit_types(shader, nir_var_ray_hit_attrib, glsl_get_natural_size_align_bytes); + /* The callback only inserts and removes straight-line instructions, so block indices and dominance are declared preserved. */ + progress |= nir_shader_instructions_pass(shader, lower_hit_attrib_deref, + nir_metadata_block_index | nir_metadata_dominance, NULL); + + if (progress) { + nir_remove_dead_derefs(shader); + nir_remove_dead_variables(shader, nir_var_ray_hit_attrib, NULL); + } + + return progress; +} diff --git a/src/amd/vulkan/radv_rt_shader.c b/src/amd/vulkan/radv_rt_shader.c index a5a0131bebf..979e34d5317 100644 --- a/src/amd/vulkan/radv_rt_shader.c +++ b/src/amd/vulkan/radv_rt_shader.c @@ -26,6 +26,7 @@ #include "bvh/bvh.h" #include "meta/radv_meta.h" +#include "nir/radv_nir.h" #include "ac_nir.h" #include "radv_private.h" #include "radv_rt_common.h" @@ -578,104 +579,6 @@ lower_rt_instructions(nir_shader *shader, struct rt_variables *vars, bool apply_ nir_shader_instructions_pass(shader, radv_lower_rt_instruction, nir_metadata_none, &data); } -static bool -lower_hit_attrib_deref(nir_builder *b, nir_instr *instr, void *data) -{ - if (instr->type != nir_instr_type_intrinsic) - return false; - - nir_intrinsic_instr *intrin = nir_instr_as_intrinsic(instr); - if (intrin->intrinsic != nir_intrinsic_load_deref && intrin->intrinsic != nir_intrinsic_store_deref) - return false; - - nir_deref_instr *deref = nir_src_as_deref(intrin->src[0]); - if (!nir_deref_mode_is(deref, nir_var_ray_hit_attrib)) - return false; -
- assert(deref->deref_type == nir_deref_type_var); - - b->cursor = nir_after_instr(instr); - - if (intrin->intrinsic == nir_intrinsic_load_deref) { - uint32_t num_components = intrin->def.num_components; - uint32_t bit_size = intrin->def.bit_size; - - nir_def *components[NIR_MAX_VEC_COMPONENTS]; - - for (uint32_t comp = 0; comp < num_components; comp++) { - uint32_t offset = deref->var->data.driver_location + comp * bit_size / 8; - uint32_t base = offset / 4; - uint32_t comp_offset = offset % 4; - - if (bit_size == 64) { - components[comp] = nir_pack_64_2x32_split(b, nir_load_hit_attrib_amd(b, .base = base), - nir_load_hit_attrib_amd(b, .base = base + 1)); - } else if (bit_size == 32) { - components[comp] = nir_load_hit_attrib_amd(b, .base = base); - } else if (bit_size == 16) { - components[comp] = - nir_channel(b, nir_unpack_32_2x16(b, nir_load_hit_attrib_amd(b, .base = base)), comp_offset / 2); - } else if (bit_size == 8) { - components[comp] = - nir_channel(b, nir_unpack_bits(b, nir_load_hit_attrib_amd(b, .base = base), 8), comp_offset); - } else { - unreachable("Invalid bit_size"); - } - } - - nir_def_rewrite_uses(&intrin->def, nir_vec(b, components, num_components)); - } else { - nir_def *value = intrin->src[1].ssa; - uint32_t num_components = value->num_components; - uint32_t bit_size = value->bit_size; - - for (uint32_t comp = 0; comp < num_components; comp++) { - uint32_t offset = deref->var->data.driver_location + comp * bit_size / 8; - uint32_t base = offset / 4; - uint32_t comp_offset = offset % 4; - - nir_def *component = nir_channel(b, value, comp); - - if (bit_size == 64) { - nir_store_hit_attrib_amd(b, nir_unpack_64_2x32_split_x(b, component), .base = base); - nir_store_hit_attrib_amd(b, nir_unpack_64_2x32_split_y(b, component), .base = base + 1); - } else if (bit_size == 32) { - nir_store_hit_attrib_amd(b, component, .base = base); - } else if (bit_size == 16) { - nir_def *prev = nir_unpack_32_2x16(b, nir_load_hit_attrib_amd(b, .base = base)); - 
nir_def *components[2]; - for (uint32_t word = 0; word < 2; word++) - components[word] = (word == comp_offset / 2) ? nir_channel(b, value, comp) : nir_channel(b, prev, word); - nir_store_hit_attrib_amd(b, nir_pack_32_2x16(b, nir_vec(b, components, 2)), .base = base); - } else if (bit_size == 8) { - nir_def *prev = nir_unpack_bits(b, nir_load_hit_attrib_amd(b, .base = base), 8); - nir_def *components[4]; - for (uint32_t byte = 0; byte < 4; byte++) - components[byte] = (byte == comp_offset) ? nir_channel(b, value, comp) : nir_channel(b, prev, byte); - nir_store_hit_attrib_amd(b, nir_pack_32_4x8(b, nir_vec(b, components, 4)), .base = base); - } else { - unreachable("Invalid bit_size"); - } - } - } - - nir_instr_remove(instr); - return true; -} - -static bool -lower_hit_attrib_derefs(nir_shader *shader) -{ - bool progress = nir_shader_instructions_pass(shader, lower_hit_attrib_deref, - nir_metadata_block_index | nir_metadata_dominance, NULL); - if (progress) { - nir_remove_dead_derefs(shader); - nir_remove_dead_variables(shader, nir_var_ray_hit_attrib, NULL); - } - - return progress; -} - /* Lowers hit attributes to registers or shared memory. If hit_attribs is NULL, attributes are * lowered to shared memory. 
*/ static void @@ -802,16 +705,11 @@ radv_parse_rt_stage(struct radv_device *device, const VkPipelineShaderStageCreat nir_shader *shader = radv_shader_spirv_to_nir(device, &rt_stage, key, false); - NIR_PASS(_, shader, nir_split_struct_vars, nir_var_ray_hit_attrib); - NIR_PASS(_, shader, nir_lower_indirect_derefs, nir_var_ray_hit_attrib, UINT32_MAX); - NIR_PASS(_, shader, nir_split_array_vars, nir_var_ray_hit_attrib); - - NIR_PASS(_, shader, nir_lower_vars_to_explicit_types, - nir_var_function_temp | nir_var_shader_call_data | nir_var_ray_hit_attrib, + NIR_PASS(_, shader, nir_lower_vars_to_explicit_types, nir_var_function_temp | nir_var_shader_call_data, glsl_get_natural_size_align_bytes); NIR_PASS(_, shader, lower_rt_derefs); - NIR_PASS(_, shader, lower_hit_attrib_derefs); + NIR_PASS(_, shader, radv_nir_lower_hit_attrib_derefs); NIR_PASS(_, shader, nir_lower_explicit_io, nir_var_function_temp, nir_address_format_32bit_offset); @@ -1485,7 +1383,7 @@ radv_build_traversal(struct radv_device *device, struct radv_ray_tracing_pipelin radv_build_ray_traversal(device, b, &args); nir_metadata_preserve(nir_shader_get_entrypoint(b->shader), nir_metadata_none); - lower_hit_attrib_derefs(b->shader); + radv_nir_lower_hit_attrib_derefs(b->shader); /* Register storage for hit attributes */ nir_variable *hit_attribs[RADV_MAX_HIT_ATTRIB_DWORDS];