Module: Mesa Branch: main Commit: c1651a103268f707cb0e7d2ffd97f9d3878fbb23 URL: http://cgit.freedesktop.org/mesa/mesa/commit/?id=c1651a103268f707cb0e7d2ffd97f9d3878fbb23
Author: Friedrich Vock <[email protected]> Date: Tue Feb 21 21:47:04 2023 +0100 radv: Extend hit attribute lowering for LDS Part-of: <https://gitlab.freedesktop.org/mesa/mesa/-/merge_requests/21635> --- src/amd/vulkan/radv_rt_shader.c | 94 +++++++++++++++++++++++++---------------- 1 file changed, 58 insertions(+), 36 deletions(-) diff --git a/src/amd/vulkan/radv_rt_shader.c b/src/amd/vulkan/radv_rt_shader.c index 9b3898f5a93..d8fc762ec05 100644 --- a/src/amd/vulkan/radv_rt_shader.c +++ b/src/amd/vulkan/radv_rt_shader.c @@ -766,6 +766,63 @@ lower_hit_attrib_derefs(nir_shader *shader) return progress; } +/* Lowers hit attributes to registers or shared memory. If hit_attribs is NULL, attributes are + * lowered to shared memory. */ +static void +lower_hit_attribs(nir_shader *shader, nir_variable **hit_attribs, uint32_t workgroup_size) +{ + nir_function_impl *impl = nir_shader_get_entrypoint(shader); + + nir_foreach_variable_with_modes (attrib, shader, nir_var_ray_hit_attrib) + attrib->data.mode = nir_var_shader_temp; + + nir_builder b; + nir_builder_init(&b, impl); + + nir_foreach_block (block, impl) { + nir_foreach_instr_safe (instr, block) { + if (instr->type != nir_instr_type_intrinsic) + continue; + + nir_intrinsic_instr *intrin = nir_instr_as_intrinsic(instr); + if (intrin->intrinsic != nir_intrinsic_load_hit_attrib_amd && + intrin->intrinsic != nir_intrinsic_store_hit_attrib_amd) + continue; + + b.cursor = nir_after_instr(instr); + + nir_ssa_def *offset; + if (!hit_attribs) + offset = nir_imul_imm(&b, + nir_iadd_imm(&b, nir_load_local_invocation_index(&b), + nir_intrinsic_base(intrin) * workgroup_size), + sizeof(uint32_t)); + + if (intrin->intrinsic == nir_intrinsic_load_hit_attrib_amd) { + nir_ssa_def *ret; + if (hit_attribs) + ret = nir_load_var(&b, hit_attribs[nir_intrinsic_base(intrin)]); + else + ret = nir_load_shared(&b, 1, 32, offset, .base = 0, .align_mul = 4); + nir_ssa_def_rewrite_uses(nir_instr_ssa_def(instr), ret); + } else { + if (hit_attribs) + nir_store_var(&b, hit_attribs[nir_intrinsic_base(intrin)], intrin->src->ssa, 0x1); + else + nir_store_shared(&b, intrin->src->ssa, offset, .base = 0, .align_mul = 4); + } + nir_instr_remove(instr); + } + } + + if (hit_attribs) { + nir_metadata_preserve(impl, nir_metadata_block_index | nir_metadata_dominance); + + nir_lower_global_vars_to_local(shader); + nir_lower_vars_to_ssa(shader); + } +} + static void insert_rt_case(nir_builder *b, nir_shader *shader, struct rt_variables *vars, nir_ssa_def *idx, uint32_t call_idx_base, uint32_t call_idx) @@ -1508,41 +1565,6 @@ move_rt_instructions(nir_shader *shader) nir_metadata_all & (~nir_metadata_instr_index)); } -static void -lower_hit_attribs(nir_shader *shader, nir_variable **hit_attribs) -{ - nir_function_impl *impl = nir_shader_get_entrypoint(shader); - - nir_foreach_variable_with_modes (attrib, shader, nir_var_ray_hit_attrib) - attrib->data.mode = nir_var_shader_temp; - - nir_builder b; - nir_builder_init(&b, impl); - - nir_foreach_block (block, impl) { - nir_foreach_instr_safe (instr, block) { - if (instr->type != nir_instr_type_intrinsic) - continue; - - b.cursor = nir_after_instr(instr); - nir_intrinsic_instr *intrin = nir_instr_as_intrinsic(instr); - if (intrin->intrinsic == nir_intrinsic_load_hit_attrib_amd) { - nir_ssa_def *ret = nir_load_var(&b, hit_attribs[nir_intrinsic_base(intrin)]); - nir_ssa_def_rewrite_uses(nir_instr_ssa_def(instr), ret); - nir_instr_remove(instr); - } else if (intrin->intrinsic == nir_intrinsic_store_hit_attrib_amd) { - nir_store_var(&b, hit_attribs[nir_intrinsic_base(intrin)], intrin->src->ssa, 0x1); - nir_instr_remove(instr); - } - } - } - - nir_metadata_preserve(impl, nir_metadata_block_index | nir_metadata_dominance); - - nir_lower_global_vars_to_local(shader); - nir_lower_vars_to_ssa(shader); -} - nir_shader * create_rt_shader(struct radv_device *device, const VkRayTracingPipelineCreateInfoKHR *pCreateInfo, struct radv_pipeline_shader_stack_size *stack_sizes, @@ -1645,7 +1667,7 @@ create_rt_shader(struct radv_device *device, const VkRayTracingPipelineCreateInf nir_index_ssa_defs(nir_shader_get_entrypoint(b.shader)); nir_metadata_preserve(nir_shader_get_entrypoint(b.shader), nir_metadata_none); - lower_hit_attribs(b.shader, hit_attribs); + lower_hit_attribs(b.shader, hit_attribs, device->physical_device->rt_wave_size); return b.shader; }
