Module: Mesa
Branch: main
Commit: 913de78731aa58c27861e62d16f50dc9249be58f
URL:    http://cgit.freedesktop.org/mesa/mesa/commit/?id=913de78731aa58c27861e62d16f50dc9249be58f

Author: Bas Nieuwenhuizen <[email protected]>
Date:   Wed Jan 11 01:30:24 2023 +0100

radv: Use provided handles for switch cases in RT shaders.

Part-of: <https://gitlab.freedesktop.org/mesa/mesa/-/merge_requests/21406>

---

 src/amd/vulkan/radv_pipeline_rt.c |  3 +-
 src/amd/vulkan/radv_rt_shader.c   | 86 ++++++++++++++++++++++++++++-----------
 src/amd/vulkan/radv_shader.h      |  2 +
 3 files changed, 66 insertions(+), 25 deletions(-)

diff --git a/src/amd/vulkan/radv_pipeline_rt.c b/src/amd/vulkan/radv_pipeline_rt.c
index 02c783295df..197dcd9f6c6 100644
--- a/src/amd/vulkan/radv_pipeline_rt.c
+++ b/src/amd/vulkan/radv_pipeline_rt.c
@@ -435,7 +435,8 @@ radv_rt_pipeline_create(VkDevice _device, VkPipelineCache _cache,
          goto pipeline_fail;
       }
 
-      shader = create_rt_shader(device, &local_create_info, 
rt_pipeline->stack_sizes, &key);
+      shader = create_rt_shader(device, &local_create_info, 
rt_pipeline->stack_sizes,
+                                rt_pipeline->group_handles, &key);
       module.nir = shader;
       result = radv_compute_pipeline_compile(
          &rt_pipeline->base, pipeline_layout, device, cache, &key, &stage, 
pCreateInfo->flags,
diff --git a/src/amd/vulkan/radv_rt_shader.c b/src/amd/vulkan/radv_rt_shader.c
index 212ac0e6110..a301dd5c033 100644
--- a/src/amd/vulkan/radv_rt_shader.c
+++ b/src/amd/vulkan/radv_rt_shader.c
@@ -1078,10 +1078,20 @@ init_traversal_vars(nir_builder *b)
    return ret;
 }
 
+struct traversal_data {
+   struct radv_device *device;
+   const VkRayTracingPipelineCreateInfoKHR *createInfo;
+   struct rt_variables *vars;
+   struct rt_traversal_vars *trav_vars;
+   nir_variable *barycentrics;
+
+   const struct radv_pipeline_group_handle *handles;
+};
+
 static void
 visit_any_hit_shaders(struct radv_device *device,
                       const VkRayTracingPipelineCreateInfoKHR *pCreateInfo, 
nir_builder *b,
-                      struct rt_variables *vars)
+                      struct traversal_data *data, struct rt_variables *vars)
 {
    nir_ssa_def *sbt_idx = nir_load_var(b, vars->idx);
 
@@ -1102,25 +1112,26 @@ visit_any_hit_shaders(struct radv_device *device,
       if (shader_id == VK_SHADER_UNUSED_KHR)
          continue;
 
+      /* Avoid emitting stages with the same shaders/handles multiple times. */
+      bool is_dup = false;
+      for (unsigned j = 0; j < i; ++j)
+         if (data->handles[j].any_hit_index == data->handles[i].any_hit_index)
+            is_dup = true;
+
+      if (is_dup)
+         continue;
+
       const VkPipelineShaderStageCreateInfo *stage = 
&pCreateInfo->pStages[shader_id];
       nir_shader *nir_stage = parse_rt_stage(device, stage, vars->key);
 
       vars->stage_idx = shader_id;
-      insert_rt_case(b, nir_stage, vars, sbt_idx, 0, i + 2);
+      insert_rt_case(b, nir_stage, vars, sbt_idx, 0, 
data->handles[i].any_hit_index);
    }
 
    if (!(vars->create_info->flags & 
VK_PIPELINE_CREATE_RAY_TRACING_NO_NULL_ANY_HIT_SHADERS_BIT_KHR))
       nir_pop_if(b, NULL);
 }
 
-struct traversal_data {
-   struct radv_device *device;
-   const VkRayTracingPipelineCreateInfoKHR *createInfo;
-   struct rt_variables *vars;
-   struct rt_traversal_vars *trav_vars;
-   nir_variable *barycentrics;
-};
-
 static void
 handle_candidate_triangle(nir_builder *b, struct radv_triangle_intersection 
*intersection,
                           const struct radv_ray_traversal_args *args,
@@ -1158,7 +1169,7 @@ handle_candidate_triangle(nir_builder *b, struct radv_triangle_intersection *int
 
       load_sbt_entry(b, &inner_vars, sbt_idx, SBT_HIT, SBT_ANY_HIT_IDX);
 
-      visit_any_hit_shaders(data->device, data->createInfo, b, &inner_vars);
+      visit_any_hit_shaders(data->device, data->createInfo, b, args->data, 
&inner_vars);
 
       nir_push_if(b, nir_inot(b, nir_load_var(b, data->vars->ahit_accept)));
       {
@@ -1237,6 +1248,15 @@ handle_candidate_aabb(nir_builder *b, struct radv_leaf_intersection *intersectio
       if (shader_id == VK_SHADER_UNUSED_KHR)
          continue;
 
+      /* Avoid emitting stages with the same shaders/handles multiple times. */
+      bool is_dup = false;
+      for (unsigned j = 0; j < i; ++j)
+         if (data->handles[j].intersection_index == 
data->handles[i].intersection_index)
+            is_dup = true;
+
+      if (is_dup)
+         continue;
+
       const VkPipelineShaderStageCreateInfo *stage = 
&data->createInfo->pStages[shader_id];
       nir_shader *nir_stage = parse_rt_stage(data->device, stage, 
data->vars->key);
 
@@ -1250,7 +1270,8 @@ handle_candidate_aabb(nir_builder *b, struct radv_leaf_intersection *intersectio
       }
 
       inner_vars.stage_idx = shader_id;
-      insert_rt_case(b, nir_stage, &inner_vars, nir_load_var(b, 
inner_vars.idx), 0, i + 2);
+      insert_rt_case(b, nir_stage, &inner_vars, nir_load_var(b, 
inner_vars.idx), 0,
+                     data->handles[i].intersection_index);
    }
 
    if (!(data->vars->create_info->flags &
@@ -1297,6 +1318,7 @@ static nir_shader *
 build_traversal_shader(struct radv_device *device,
                        const VkRayTracingPipelineCreateInfoKHR *pCreateInfo,
                        struct radv_pipeline_shader_stack_size *stack_sizes,
+                       const struct radv_pipeline_group_handle *handles,
                        const struct radv_pipeline_key *key)
 {
    /* Create the traversal shader as an intersection shader to prevent 
validation failures due to
@@ -1383,6 +1405,7 @@ build_traversal_shader(struct radv_device *device,
       .vars = &vars,
       .trav_vars = &trav_vars,
       .barycentrics = barycentrics,
+      .handles = handles,
    };
 
    struct radv_ray_traversal_args args = {
@@ -1518,6 +1541,7 @@ lower_hit_attribs(nir_shader *shader, nir_variable **hit_attribs)
 nir_shader *
 create_rt_shader(struct radv_device *device, const 
VkRayTracingPipelineCreateInfoKHR *pCreateInfo,
                  struct radv_pipeline_shader_stack_size *stack_sizes,
+                 const struct radv_pipeline_group_handle *handles,
                  const struct radv_pipeline_key *key)
 {
    nir_builder b = radv_meta_init_shader(device, MESA_SHADER_COMPUTE, 
"rt_combined");
@@ -1554,23 +1578,37 @@ create_rt_shader(struct radv_device *device, const VkRayTracingPipelineCreateInf
    nir_ssa_def *idx = nir_load_var(&b, vars.idx);
 
    /* Insert traversal shader */
-   nir_shader *traversal = build_traversal_shader(device, pCreateInfo, 
stack_sizes, key);
+   nir_shader *traversal = build_traversal_shader(device, pCreateInfo, 
stack_sizes, handles, key);
    assert(b.shader->info.shared_size == 0);
    b.shader->info.shared_size = traversal->info.shared_size;
    assert(b.shader->info.shared_size <= 32768);
    insert_rt_case(&b, traversal, &vars, idx, 0, 1);
 
-   /* We do a trick with the indexing of the resume shaders so that the first
-    * shader of stage x always gets id x and the resume shader ids then come 
after
-    * stageCount. This makes the shadergroup handles independent of 
compilation. */
-   unsigned call_idx_base = pCreateInfo->stageCount + 1;
-   for (unsigned i = 0; i < pCreateInfo->stageCount; ++i) {
-      const VkPipelineShaderStageCreateInfo *stage = &pCreateInfo->pStages[i];
-      gl_shader_stage type = vk_to_mesa_shader_stage(stage->stage);
-      if (type != MESA_SHADER_RAYGEN && type != MESA_SHADER_CALLABLE &&
-          type != MESA_SHADER_CLOSEST_HIT && type != MESA_SHADER_MISS)
+   unsigned call_idx_base = 1;
+   for (unsigned i = 0; i < pCreateInfo->groupCount; ++i) {
+      unsigned stage_idx = VK_SHADER_UNUSED_KHR;
+      if (pCreateInfo->pGroups[i].type == 
VK_RAY_TRACING_SHADER_GROUP_TYPE_GENERAL_KHR)
+         stage_idx = pCreateInfo->pGroups[i].generalShader;
+      else
+         stage_idx = pCreateInfo->pGroups[i].closestHitShader;
+
+      if (stage_idx == VK_SHADER_UNUSED_KHR)
          continue;
 
+      /* Avoid emitting stages with the same shaders/handles multiple times. */
+      bool is_dup = false;
+      for (unsigned j = 0; j < i; ++j)
+         if (handles[j].general_index == handles[i].general_index)
+            is_dup = true;
+
+      if (is_dup)
+         continue;
+
+      const VkPipelineShaderStageCreateInfo *stage = 
&pCreateInfo->pStages[stage_idx];
+      ASSERTED gl_shader_stage type = vk_to_mesa_shader_stage(stage->stage);
+      assert(type == MESA_SHADER_RAYGEN || type == MESA_SHADER_CALLABLE ||
+             type == MESA_SHADER_CLOSEST_HIT || type == MESA_SHADER_MISS);
+
       nir_shader *nir_stage = parse_rt_stage(device, stage, key);
 
       /* Move ray tracing system values to the top that are set by rt_trace_ray
@@ -1588,8 +1626,8 @@ create_rt_shader(struct radv_device *device, const VkRayTracingPipelineCreateInf
       nir_shader **resume_shaders = NULL;
       nir_lower_shader_calls(nir_stage, &opts, &resume_shaders, 
&num_resume_shaders, nir_stage);
 
-      vars.stage_idx = i;
-      insert_rt_case(&b, nir_stage, &vars, idx, call_idx_base, i + 2);
+      vars.stage_idx = stage_idx;
+      insert_rt_case(&b, nir_stage, &vars, idx, call_idx_base, 
handles[i].general_index);
       for (unsigned j = 0; j < num_resume_shaders; ++j) {
          insert_rt_case(&b, resume_shaders[j], &vars, idx, call_idx_base, 
call_idx_base + 1 + j);
       }
diff --git a/src/amd/vulkan/radv_shader.h b/src/amd/vulkan/radv_shader.h
index f928e3d3a75..d9d69708ff9 100644
--- a/src/amd/vulkan/radv_shader.h
+++ b/src/amd/vulkan/radv_shader.h
@@ -47,6 +47,7 @@ struct radv_physical_device;
 struct radv_device;
 struct radv_pipeline;
 struct radv_pipeline_cache;
+struct radv_pipeline_group_handle;
 struct radv_pipeline_key;
 struct radv_shader_args;
 struct radv_vs_input_state;
@@ -755,6 +756,7 @@ bool radv_lower_fs_intrinsics(nir_shader *nir, const struct radv_pipeline_stage
 nir_shader *create_rt_shader(struct radv_device *device,
                              const VkRayTracingPipelineCreateInfoKHR 
*pCreateInfo,
                              struct radv_pipeline_shader_stack_size 
*stack_sizes,
+                             const struct radv_pipeline_group_handle *handles,
                              const struct radv_pipeline_key *key);
 
 #endif

Reply via email to