Module: Mesa
Branch: main
Commit: b69ec8bde3397551d901a76764af71483967480d
URL:    http://cgit.freedesktop.org/mesa/mesa/commit/?id=b69ec8bde3397551d901a76764af71483967480d

Author: Konstantin Seurer <[email protected]>
Date:   Tue Feb 21 11:45:09 2023 +0100

radv/rt: Refactor rq_load lowering

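This just gets rid of all the bcsel emissions: the "committed" source of
rq_load is asserted to be constant, so the lowering picks the closest or
candidate intersection variables directly instead of selecting between both
at runtime. Constant folding is run first to collapse expressions that the
spec requires to be constant.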

Part-of: <https://gitlab.freedesktop.org/mesa/mesa/-/merge_requests/21455>

---

 src/amd/vulkan/radv_nir_lower_ray_queries.c | 92 +++++++++++++----------------
 1 file changed, 40 insertions(+), 52 deletions(-)

diff --git a/src/amd/vulkan/radv_nir_lower_ray_queries.c 
b/src/amd/vulkan/radv_nir_lower_ray_queries.c
index 831ab4ad8a7..1bf4fde25ad 100644
--- a/src/amd/vulkan/radv_nir_lower_ray_queries.c
+++ b/src/amd/vulkan/radv_nir_lower_ray_queries.c
@@ -428,32 +428,31 @@ lower_rq_initialize(nir_builder *b, nir_ssa_def *index, 
nir_intrinsic_instr *ins
 }
 
 static nir_ssa_def *
-lower_rq_load(nir_builder *b, nir_ssa_def *index, struct ray_query_vars *vars,
-              nir_ssa_def *committed, nir_ray_query_value value, unsigned 
column)
+lower_rq_load(nir_builder *b, nir_ssa_def *index, nir_intrinsic_instr *instr,
+              struct ray_query_vars *vars)
 {
+   assert(nir_src_is_const(instr->src[1]));
+   bool closest = nir_src_as_bool(instr->src[1]);
+   struct ray_query_intersection_vars *intersection = closest ? &vars->closest 
: &vars->candidate;
+
+   uint32_t column = nir_intrinsic_column(instr);
+
+   nir_ray_query_value value = nir_intrinsic_ray_query_value(instr);
    switch (value) {
    case nir_ray_query_value_flags:
       return rq_load_var(b, index, vars->flags);
    case nir_ray_query_value_intersection_barycentrics:
-      return nir_bcsel(b, committed, rq_load_var(b, index, 
vars->closest.barycentrics),
-                       rq_load_var(b, index, vars->candidate.barycentrics));
+      return rq_load_var(b, index, intersection->barycentrics);
    case nir_ray_query_value_intersection_candidate_aabb_opaque:
       return nir_iand(b, rq_load_var(b, index, vars->candidate.opaque),
                       nir_ieq_imm(b, rq_load_var(b, index, 
vars->candidate.intersection_type),
                                   intersection_type_aabb));
    case nir_ray_query_value_intersection_front_face:
-      return nir_bcsel(b, committed, rq_load_var(b, index, 
vars->closest.frontface),
-                       rq_load_var(b, index, vars->candidate.frontface));
+      return rq_load_var(b, index, intersection->frontface);
    case nir_ray_query_value_intersection_geometry_index:
-      return nir_iand_imm(
-         b,
-         nir_bcsel(b, committed, rq_load_var(b, index, 
vars->closest.geometry_id_and_flags),
-                   rq_load_var(b, index, 
vars->candidate.geometry_id_and_flags)),
-         0xFFFFFF);
+      return nir_iand_imm(b, rq_load_var(b, index, 
intersection->geometry_id_and_flags), 0xFFFFFF);
    case nir_ray_query_value_intersection_instance_custom_index: {
-      nir_ssa_def *instance_node_addr =
-         nir_bcsel(b, committed, rq_load_var(b, index, 
vars->closest.instance_addr),
-                   rq_load_var(b, index, vars->candidate.instance_addr));
+      nir_ssa_def *instance_node_addr = rq_load_var(b, index, 
intersection->instance_addr);
       return nir_iand_imm(b,
                           nir_build_load_global(b, 1, 32,
                                                 nir_iadd_imm(b, 
instance_node_addr,
@@ -462,39 +461,27 @@ lower_rq_load(nir_builder *b, nir_ssa_def *index, struct 
ray_query_vars *vars,
                           0xFFFFFF);
    }
    case nir_ray_query_value_intersection_instance_id: {
-      nir_ssa_def *instance_node_addr =
-         nir_bcsel(b, committed, rq_load_var(b, index, 
vars->closest.instance_addr),
-                   rq_load_var(b, index, vars->candidate.instance_addr));
+      nir_ssa_def *instance_node_addr = rq_load_var(b, index, 
intersection->instance_addr);
       return nir_build_load_global(
          b, 1, 32,
          nir_iadd_imm(b, instance_node_addr, offsetof(struct 
radv_bvh_instance_node, instance_id)));
    }
    case nir_ray_query_value_intersection_instance_sbt_index:
-      return nir_iand_imm(
-         b,
-         nir_bcsel(b, committed, rq_load_var(b, index, 
vars->closest.sbt_offset_and_flags),
-                   rq_load_var(b, index, 
vars->candidate.sbt_offset_and_flags)),
-         0xFFFFFF);
+      return nir_iand_imm(b, rq_load_var(b, index, 
intersection->sbt_offset_and_flags), 0xFFFFFF);
    case nir_ray_query_value_intersection_object_ray_direction: {
-      nir_ssa_def *instance_node_addr =
-         nir_bcsel(b, committed, rq_load_var(b, index, 
vars->closest.instance_addr),
-                   rq_load_var(b, index, vars->candidate.instance_addr));
+      nir_ssa_def *instance_node_addr = rq_load_var(b, index, 
intersection->instance_addr);
       nir_ssa_def *wto_matrix[3];
       nir_build_wto_matrix_load(b, instance_node_addr, wto_matrix);
       return nir_build_vec3_mat_mult(b, rq_load_var(b, index, 
vars->direction), wto_matrix, false);
    }
    case nir_ray_query_value_intersection_object_ray_origin: {
-      nir_ssa_def *instance_node_addr =
-         nir_bcsel(b, committed, rq_load_var(b, index, 
vars->closest.instance_addr),
-                   rq_load_var(b, index, vars->candidate.instance_addr));
+      nir_ssa_def *instance_node_addr = rq_load_var(b, index, 
intersection->instance_addr);
       nir_ssa_def *wto_matrix[3];
       nir_build_wto_matrix_load(b, instance_node_addr, wto_matrix);
       return nir_build_vec3_mat_mult(b, rq_load_var(b, index, vars->origin), 
wto_matrix, true);
    }
    case nir_ray_query_value_intersection_object_to_world: {
-      nir_ssa_def *instance_node_addr =
-         nir_bcsel(b, committed, rq_load_var(b, index, 
vars->closest.instance_addr),
-                   rq_load_var(b, index, vars->candidate.instance_addr));
+      nir_ssa_def *instance_node_addr = rq_load_var(b, index, 
intersection->instance_addr);
       nir_ssa_def *rows[3];
       for (unsigned r = 0; r < 3; ++r)
          rows[r] = nir_build_load_global(
@@ -506,19 +493,18 @@ lower_rq_load(nir_builder *b, nir_ssa_def *index, struct 
ray_query_vars *vars,
                       nir_channel(b, rows[2], column));
    }
    case nir_ray_query_value_intersection_primitive_index:
-      return nir_bcsel(b, committed, rq_load_var(b, index, 
vars->closest.primitive_id),
-                       rq_load_var(b, index, vars->candidate.primitive_id));
+      return rq_load_var(b, index, intersection->primitive_id);
    case nir_ray_query_value_intersection_t:
-      return nir_bcsel(b, committed, rq_load_var(b, index, vars->closest.t),
-                       rq_load_var(b, index, vars->candidate.t));
-   case nir_ray_query_value_intersection_type:
-      return nir_bcsel(
-         b, committed, rq_load_var(b, index, vars->closest.intersection_type),
-         nir_iadd_imm(b, rq_load_var(b, index, 
vars->candidate.intersection_type), -1));
+      return rq_load_var(b, index, intersection->t);
+   case nir_ray_query_value_intersection_type: {
+      nir_ssa_def *intersection_type = rq_load_var(b, index, 
intersection->intersection_type);
+      if (!closest)
+         intersection_type = nir_iadd_imm(b, intersection_type, -1);
+
+      return intersection_type;
+   }
    case nir_ray_query_value_intersection_world_to_object: {
-      nir_ssa_def *instance_node_addr =
-         nir_bcsel(b, committed, rq_load_var(b, index, 
vars->closest.instance_addr),
-                   rq_load_var(b, index, vars->candidate.instance_addr));
+      nir_ssa_def *instance_node_addr = rq_load_var(b, index, 
intersection->instance_addr);
 
       nir_ssa_def *wto_matrix[3];
       nir_build_wto_matrix_load(b, instance_node_addr, wto_matrix);
@@ -691,15 +677,19 @@ lower_rq_terminate(nir_builder *b, nir_ssa_def *index, 
nir_intrinsic_instr *inst
 bool
 radv_nir_lower_ray_queries(struct nir_shader *shader, struct radv_device 
*device)
 {
-   bool contains_ray_query = false;
+   bool progress = false;
    struct hash_table *query_ht = _mesa_pointer_hash_table_create(NULL);
 
+   /* Run constant folding to collapse expressions that are required to be 
constant by the spec. */
+   NIR_PASS(progress, shader, nir_opt_constant_folding);
+
    nir_foreach_variable_in_list (var, &shader->variables) {
       if (!var->data.ray_query)
          continue;
 
       lower_ray_query(shader, var, query_ht, 
device->physical_device->max_shared_size);
-      contains_ray_query = true;
+
+      progress = true;
    }
 
    nir_foreach_function (function, shader) {
@@ -714,11 +704,9 @@ radv_nir_lower_ray_queries(struct nir_shader *shader, 
struct radv_device *device
             continue;
 
          lower_ray_query(shader, var, query_ht, 
device->physical_device->max_shared_size);
-         contains_ray_query = true;
-      }
 
-      if (!contains_ray_query)
-         continue;
+         progress = true;
+      }
 
       nir_foreach_block (block, function->impl) {
          nir_foreach_instr_safe (instr, block) {
@@ -760,9 +748,7 @@ radv_nir_lower_ray_queries(struct nir_shader *shader, 
struct radv_device *device
                lower_rq_initialize(&builder, index, intrinsic, vars);
                break;
             case nir_intrinsic_rq_load:
-               new_dest = lower_rq_load(&builder, index, vars, 
intrinsic->src[1].ssa,
-                                        
nir_intrinsic_ray_query_value(intrinsic),
-                                        nir_intrinsic_column(intrinsic));
+               new_dest = lower_rq_load(&builder, index, intrinsic, vars);
                break;
             case nir_intrinsic_rq_proceed:
                new_dest = lower_rq_proceed(&builder, index, vars, device);
@@ -779,6 +765,8 @@ radv_nir_lower_ray_queries(struct nir_shader *shader, 
struct radv_device *device
 
             nir_instr_remove(instr);
             nir_instr_free(instr);
+
+            progress = true;
          }
       }
 
@@ -787,5 +775,5 @@ radv_nir_lower_ray_queries(struct nir_shader *shader, 
struct radv_device *device
 
    ralloc_free(query_ht);
 
-   return contains_ray_query;
+   return progress;
 }

Reply via email to