Module: Mesa
Branch: main
Commit: 3a69424e09e0069d8f2ed04c7018cd17d66df743
URL:    http://cgit.freedesktop.org/mesa/mesa/commit/?id=3a69424e09e0069d8f2ed04c7018cd17d66df743

Author: Konstantin Seurer <konstantin.seu...@gmail.com>
Date:   Thu Jul 20 21:13:22 2023 +0200

radv/nir: Add radv_nir_lower_hit_attrib_derefs

Move out the pass so it can be unit tested.

Reviewed-by: Friedrich Vock <friedrich.v...@gmx.de>
Part-of: <https://gitlab.freedesktop.org/mesa/mesa/-/merge_requests/24271>

---

 src/amd/vulkan/meson.build                         |   1 +
 src/amd/vulkan/nir/radv_nir.h                      |   2 +
 .../vulkan/nir/radv_nir_lower_hit_attrib_derefs.c  | 116 +++++++++++++++++++++
 src/amd/vulkan/radv_rt_shader.c                    | 110 +------------------
 4 files changed, 123 insertions(+), 106 deletions(-)

diff --git a/src/amd/vulkan/meson.build b/src/amd/vulkan/meson.build
index b72fcfcb0fe..36f664d35a9 100644
--- a/src/amd/vulkan/meson.build
+++ b/src/amd/vulkan/meson.build
@@ -79,6 +79,7 @@ libradv_files = files(
   'nir/radv_nir_lower_cooperative_matrix.c',
   'nir/radv_nir_lower_fs_barycentric.c',
   'nir/radv_nir_lower_fs_intrinsics.c',
+  'nir/radv_nir_lower_hit_attrib_derefs.c',
   'nir/radv_nir_lower_intrinsics_early.c',
   'nir/radv_nir_lower_io.c',
   'nir/radv_nir_lower_poly_line_smooth.c',
diff --git a/src/amd/vulkan/nir/radv_nir.h b/src/amd/vulkan/nir/radv_nir.h
index 44a51d747e1..a93678dc214 100644
--- a/src/amd/vulkan/nir/radv_nir.h
+++ b/src/amd/vulkan/nir/radv_nir.h
@@ -50,6 +50,8 @@ void radv_nir_lower_abi(nir_shader *shader, enum amd_gfx_level gfx_level, const
                         const struct radv_shader_args *args, const struct radv_pipeline_key *pl_key,
                         uint32_t address32_hi);
 
+bool radv_nir_lower_hit_attrib_derefs(nir_shader *shader);
+
 bool radv_nir_lower_ray_queries(struct nir_shader *shader, struct radv_device *device);
 
 bool radv_nir_lower_vs_inputs(nir_shader *shader, const struct radv_shader_stage *vs_stage,
diff --git a/src/amd/vulkan/nir/radv_nir_lower_hit_attrib_derefs.c b/src/amd/vulkan/nir/radv_nir_lower_hit_attrib_derefs.c
new file mode 100644
index 00000000000..601a40a81c6
--- /dev/null
+++ b/src/amd/vulkan/nir/radv_nir_lower_hit_attrib_derefs.c
@@ -0,0 +1,116 @@
+/*
+ * Copyright © 2021 Google
+ * Copyright © 2023 Valve Corporation
+ * SPDX-License-Identifier: MIT
+ */
+
+#include "nir.h"
+#include "nir_builder.h"
+#include "radv_nir.h"
+
+static bool
+lower_hit_attrib_deref(nir_builder *b, nir_instr *instr, void *data)
+{
+   if (instr->type != nir_instr_type_intrinsic)
+      return false;
+
+   nir_intrinsic_instr *intrin = nir_instr_as_intrinsic(instr);
+   if (intrin->intrinsic != nir_intrinsic_load_deref && intrin->intrinsic != nir_intrinsic_store_deref)
+      return false;
+
+   nir_deref_instr *deref = nir_src_as_deref(intrin->src[0]);
+   if (!nir_deref_mode_is(deref, nir_var_ray_hit_attrib))
+      return false;
+
+   assert(deref->deref_type == nir_deref_type_var);
+
+   b->cursor = nir_after_instr(instr);
+
+   if (intrin->intrinsic == nir_intrinsic_load_deref) {
+      uint32_t num_components = intrin->def.num_components;
+      uint32_t bit_size = intrin->def.bit_size;
+
+      nir_def *components[NIR_MAX_VEC_COMPONENTS];
+
+      for (uint32_t comp = 0; comp < num_components; comp++) {
+         uint32_t offset = deref->var->data.driver_location + comp * bit_size / 8;
+         uint32_t base = offset / 4;
+         uint32_t comp_offset = offset % 4;
+
+         if (bit_size == 64) {
+            components[comp] = nir_pack_64_2x32_split(b, nir_load_hit_attrib_amd(b, .base = base),
+                                                      nir_load_hit_attrib_amd(b, .base = base + 1));
+         } else if (bit_size == 32) {
+            components[comp] = nir_load_hit_attrib_amd(b, .base = base);
+         } else if (bit_size == 16) {
+            components[comp] =
+               nir_channel(b, nir_unpack_32_2x16(b, nir_load_hit_attrib_amd(b, .base = base)), comp_offset / 2);
+         } else if (bit_size == 8) {
+            components[comp] =
+               nir_channel(b, nir_unpack_bits(b, nir_load_hit_attrib_amd(b, .base = base), 8), comp_offset);
+         } else {
+            unreachable("Invalid bit_size");
+         }
+      }
+
+      nir_def_rewrite_uses(&intrin->def, nir_vec(b, components, num_components));
+   } else {
+      nir_def *value = intrin->src[1].ssa;
+      uint32_t num_components = value->num_components;
+      uint32_t bit_size = value->bit_size;
+
+      for (uint32_t comp = 0; comp < num_components; comp++) {
+         uint32_t offset = deref->var->data.driver_location + comp * bit_size / 8;
+         uint32_t base = offset / 4;
+         uint32_t comp_offset = offset % 4;
+
+         nir_def *component = nir_channel(b, value, comp);
+
+         if (bit_size == 64) {
+            nir_store_hit_attrib_amd(b, nir_unpack_64_2x32_split_x(b, component), .base = base);
+            nir_store_hit_attrib_amd(b, nir_unpack_64_2x32_split_y(b, component), .base = base + 1);
+         } else if (bit_size == 32) {
+            nir_store_hit_attrib_amd(b, component, .base = base);
+         } else if (bit_size == 16) {
+            nir_def *prev = nir_unpack_32_2x16(b, nir_load_hit_attrib_amd(b, .base = base));
+            nir_def *components[2];
+            for (uint32_t word = 0; word < 2; word++)
+               components[word] = (word == comp_offset / 2) ? nir_channel(b, value, comp) : nir_channel(b, prev, word);
+            nir_store_hit_attrib_amd(b, nir_pack_32_2x16(b, nir_vec(b, components, 2)), .base = base);
+         } else if (bit_size == 8) {
+            nir_def *prev = nir_unpack_bits(b, nir_load_hit_attrib_amd(b, .base = base), 8);
+            nir_def *components[4];
+            for (uint32_t byte = 0; byte < 4; byte++)
+               components[byte] = (byte == comp_offset) ? nir_channel(b, value, comp) : nir_channel(b, prev, byte);
+            nir_store_hit_attrib_amd(b, nir_pack_32_4x8(b, nir_vec(b, components, 4)), .base = base);
+         } else {
+            unreachable("Invalid bit_size");
+         }
+      }
+   }
+
+   nir_instr_remove(instr);
+   return true;
+}
+
+bool
+radv_nir_lower_hit_attrib_derefs(nir_shader *shader)
+{
+   bool progress = false;
+
+   progress |= nir_split_struct_vars(shader, nir_var_ray_hit_attrib);
+   progress |= nir_lower_indirect_derefs(shader, nir_var_ray_hit_attrib, UINT32_MAX);
+   progress |= nir_split_array_vars(shader, nir_var_ray_hit_attrib);
+
+   progress |= nir_lower_vars_to_explicit_types(shader, nir_var_ray_hit_attrib, glsl_get_natural_size_align_bytes);
+
+   progress |= nir_shader_instructions_pass(shader, lower_hit_attrib_deref,
+                                            nir_metadata_block_index | nir_metadata_dominance, NULL);
+
+   if (progress) {
+      nir_remove_dead_derefs(shader);
+      nir_remove_dead_variables(shader, nir_var_ray_hit_attrib, NULL);
+   }
+
+   return progress;
+}
diff --git a/src/amd/vulkan/radv_rt_shader.c b/src/amd/vulkan/radv_rt_shader.c
index a5a0131bebf..979e34d5317 100644
--- a/src/amd/vulkan/radv_rt_shader.c
+++ b/src/amd/vulkan/radv_rt_shader.c
@@ -26,6 +26,7 @@
 
 #include "bvh/bvh.h"
 #include "meta/radv_meta.h"
+#include "nir/radv_nir.h"
 #include "ac_nir.h"
 #include "radv_private.h"
 #include "radv_rt_common.h"
@@ -578,104 +579,6 @@ lower_rt_instructions(nir_shader *shader, struct rt_variables *vars, bool apply_
    nir_shader_instructions_pass(shader, radv_lower_rt_instruction, nir_metadata_none, &data);
 }
 
-static bool
-lower_hit_attrib_deref(nir_builder *b, nir_instr *instr, void *data)
-{
-   if (instr->type != nir_instr_type_intrinsic)
-      return false;
-
-   nir_intrinsic_instr *intrin = nir_instr_as_intrinsic(instr);
-   if (intrin->intrinsic != nir_intrinsic_load_deref && intrin->intrinsic != nir_intrinsic_store_deref)
-      return false;
-
-   nir_deref_instr *deref = nir_src_as_deref(intrin->src[0]);
-   if (!nir_deref_mode_is(deref, nir_var_ray_hit_attrib))
-      return false;
-
-   assert(deref->deref_type == nir_deref_type_var);
-
-   b->cursor = nir_after_instr(instr);
-
-   if (intrin->intrinsic == nir_intrinsic_load_deref) {
-      uint32_t num_components = intrin->def.num_components;
-      uint32_t bit_size = intrin->def.bit_size;
-
-      nir_def *components[NIR_MAX_VEC_COMPONENTS];
-
-      for (uint32_t comp = 0; comp < num_components; comp++) {
-         uint32_t offset = deref->var->data.driver_location + comp * bit_size / 8;
-         uint32_t base = offset / 4;
-         uint32_t comp_offset = offset % 4;
-
-         if (bit_size == 64) {
-            components[comp] = nir_pack_64_2x32_split(b, nir_load_hit_attrib_amd(b, .base = base),
-                                                      nir_load_hit_attrib_amd(b, .base = base + 1));
-         } else if (bit_size == 32) {
-            components[comp] = nir_load_hit_attrib_amd(b, .base = base);
-         } else if (bit_size == 16) {
-            components[comp] =
-               nir_channel(b, nir_unpack_32_2x16(b, nir_load_hit_attrib_amd(b, .base = base)), comp_offset / 2);
-         } else if (bit_size == 8) {
-            components[comp] =
-               nir_channel(b, nir_unpack_bits(b, nir_load_hit_attrib_amd(b, .base = base), 8), comp_offset);
-         } else {
-            unreachable("Invalid bit_size");
-         }
-      }
-
-      nir_def_rewrite_uses(&intrin->def, nir_vec(b, components, num_components));
-   } else {
-      nir_def *value = intrin->src[1].ssa;
-      uint32_t num_components = value->num_components;
-      uint32_t bit_size = value->bit_size;
-
-      for (uint32_t comp = 0; comp < num_components; comp++) {
-         uint32_t offset = deref->var->data.driver_location + comp * bit_size / 8;
-         uint32_t base = offset / 4;
-         uint32_t comp_offset = offset % 4;
-
-         nir_def *component = nir_channel(b, value, comp);
-
-         if (bit_size == 64) {
-            nir_store_hit_attrib_amd(b, nir_unpack_64_2x32_split_x(b, component), .base = base);
-            nir_store_hit_attrib_amd(b, nir_unpack_64_2x32_split_y(b, component), .base = base + 1);
-         } else if (bit_size == 32) {
-            nir_store_hit_attrib_amd(b, component, .base = base);
-         } else if (bit_size == 16) {
-            nir_def *prev = nir_unpack_32_2x16(b, nir_load_hit_attrib_amd(b, .base = base));
-            nir_def *components[2];
-            for (uint32_t word = 0; word < 2; word++)
-               components[word] = (word == comp_offset / 2) ? nir_channel(b, value, comp) : nir_channel(b, prev, word);
-            nir_store_hit_attrib_amd(b, nir_pack_32_2x16(b, nir_vec(b, components, 2)), .base = base);
-         } else if (bit_size == 8) {
-            nir_def *prev = nir_unpack_bits(b, nir_load_hit_attrib_amd(b, .base = base), 8);
-            nir_def *components[4];
-            for (uint32_t byte = 0; byte < 4; byte++)
-               components[byte] = (byte == comp_offset) ? nir_channel(b, value, comp) : nir_channel(b, prev, byte);
-            nir_store_hit_attrib_amd(b, nir_pack_32_4x8(b, nir_vec(b, components, 4)), .base = base);
-         } else {
-            unreachable("Invalid bit_size");
-         }
-      }
-   }
-
-   nir_instr_remove(instr);
-   return true;
-}
-
-static bool
-lower_hit_attrib_derefs(nir_shader *shader)
-{
-   bool progress = nir_shader_instructions_pass(shader, lower_hit_attrib_deref,
-                                                nir_metadata_block_index | nir_metadata_dominance, NULL);
-   if (progress) {
-      nir_remove_dead_derefs(shader);
-      nir_remove_dead_variables(shader, nir_var_ray_hit_attrib, NULL);
-   }
-
-   return progress;
-}
-
 /* Lowers hit attributes to registers or shared memory. If hit_attribs is NULL, attributes are
  * lowered to shared memory. */
 static void
@@ -802,16 +705,11 @@ radv_parse_rt_stage(struct radv_device *device, const VkPipelineShaderStageCreat
 
    nir_shader *shader = radv_shader_spirv_to_nir(device, &rt_stage, key, false);
 
-   NIR_PASS(_, shader, nir_split_struct_vars, nir_var_ray_hit_attrib);
-   NIR_PASS(_, shader, nir_lower_indirect_derefs, nir_var_ray_hit_attrib, UINT32_MAX);
-   NIR_PASS(_, shader, nir_split_array_vars, nir_var_ray_hit_attrib);
-
-   NIR_PASS(_, shader, nir_lower_vars_to_explicit_types,
-            nir_var_function_temp | nir_var_shader_call_data | nir_var_ray_hit_attrib,
+   NIR_PASS(_, shader, nir_lower_vars_to_explicit_types, nir_var_function_temp | nir_var_shader_call_data,
             glsl_get_natural_size_align_bytes);
 
    NIR_PASS(_, shader, lower_rt_derefs);
-   NIR_PASS(_, shader, lower_hit_attrib_derefs);
+   NIR_PASS(_, shader, radv_nir_lower_hit_attrib_derefs);
 
    NIR_PASS(_, shader, nir_lower_explicit_io, nir_var_function_temp, nir_address_format_32bit_offset);
 
@@ -1485,7 +1383,7 @@ radv_build_traversal(struct radv_device *device, struct radv_ray_tracing_pipelin
    radv_build_ray_traversal(device, b, &args);
 
    nir_metadata_preserve(nir_shader_get_entrypoint(b->shader), nir_metadata_none);
-   lower_hit_attrib_derefs(b->shader);
+   radv_nir_lower_hit_attrib_derefs(b->shader);
 
    /* Register storage for hit attributes */
    nir_variable *hit_attribs[RADV_MAX_HIT_ATTRIB_DWORDS];

Reply via email to