Module: Mesa
Branch: main
Commit: 8ec81a43cba3c71440a054627f9e94b2a6b9122f
URL:    http://cgit.freedesktop.org/mesa/mesa/commit/?id=8ec81a43cba3c71440a054627f9e94b2a6b9122f

Author: Daniel Schürmann <[email protected]>
Date:   Thu Mar 23 15:18:29 2023 +0100

radv/rt: use precompiled stages to create RT shader

Part-of: <https://gitlab.freedesktop.org/mesa/mesa/-/merge_requests/22100>

---

 src/amd/vulkan/radv_pipeline_rt.c | 20 +++++++++++++++++++-
 src/amd/vulkan/radv_rt_shader.c   | 31 ++++++++++++++++++-------------
 src/amd/vulkan/radv_shader.h      |  2 ++
 3 files changed, 39 insertions(+), 14 deletions(-)

diff --git a/src/amd/vulkan/radv_pipeline_rt.c b/src/amd/vulkan/radv_pipeline_rt.c
index 306ad5ca873..b4bea1220e6 100644
--- a/src/amd/vulkan/radv_pipeline_rt.c
+++ b/src/amd/vulkan/radv_pipeline_rt.c
@@ -643,6 +643,13 @@ radv_rt_pipeline_create(VkDevice _device, VkPipelineCache _cache,
    if (result != VK_SUCCESS)
       goto pipeline_fail;
 
+   struct radv_ray_tracing_stage *stages = 
calloc(local_create_info.stageCount, sizeof(*stages));
+   if (!stages) {
+      result = VK_ERROR_OUT_OF_HOST_MEMORY;
+      goto pipeline_fail;
+   }
+   radv_rt_fill_stage_info(pCreateInfo, stages);
+
    const VkPipelineCreationFeedbackCreateInfo *creation_feedback =
       vk_find_struct_const(pCreateInfo->pNext, 
PIPELINE_CREATION_FEEDBACK_CREATE_INFO);
 
@@ -664,7 +671,11 @@ radv_rt_pipeline_create(VkDevice _device, VkPipelineCache _cache,
       if (pCreateInfo->flags & 
VK_PIPELINE_CREATE_FAIL_ON_PIPELINE_COMPILE_REQUIRED_BIT)
          goto pipeline_fail;
 
-      shader = create_rt_shader(device, &local_create_info, 
rt_pipeline->groups, &key);
+      result = radv_rt_precompile_shaders(device, cache, pCreateInfo, &key, 
stages);
+      if (result != VK_SUCCESS)
+         goto shader_fail;
+
+      shader = create_rt_shader(device, &local_create_info, stages, 
rt_pipeline->groups, &key);
       module.nir = shader;
       result = radv_rt_pipeline_compile(rt_pipeline, pipeline_layout, device, 
cache, &key, &stage,
                                         pCreateInfo->flags, hash, 
creation_feedback,
@@ -688,8 +699,15 @@ radv_rt_pipeline_create(VkDevice _device, VkPipelineCache _cache,
    radv_rmv_log_compute_pipeline_create(device, pCreateInfo->flags, 
&rt_pipeline->base.base, false);
 
    *pPipeline = radv_pipeline_to_handle(&rt_pipeline->base.base);
+
 shader_fail:
+   for (unsigned i = 0; stages && i < local_create_info.stageCount; i++) {
+      if (stages[i].shader)
+         vk_pipeline_cache_object_unref(&device->vk, stages[i].shader);
+   }
    ralloc_free(shader);
+   free(stages);
+
 pipeline_fail:
    if (result != VK_SUCCESS)
       radv_pipeline_destroy(device, &rt_pipeline->base.base, pAllocator);
diff --git a/src/amd/vulkan/radv_rt_shader.c b/src/amd/vulkan/radv_rt_shader.c
index 52dfd24c408..19a5997927b 100644
--- a/src/amd/vulkan/radv_rt_shader.c
+++ b/src/amd/vulkan/radv_rt_shader.c
@@ -1190,6 +1190,7 @@ struct traversal_data {
    nir_variable *barycentrics;
 
    struct radv_ray_tracing_group *groups;
+   struct radv_ray_tracing_stage *stages;
    const struct radv_pipeline_key *key;
 };
 
@@ -1226,8 +1227,9 @@ visit_any_hit_shaders(struct radv_device *device,
       if (is_dup)
          continue;
 
-      const VkPipelineShaderStageCreateInfo *stage = 
&pCreateInfo->pStages[shader_id];
-      nir_shader *nir_stage = radv_parse_rt_stage(device, stage, data->key);
+      nir_shader *nir_stage =
+         radv_pipeline_cache_handle_to_nir(device, 
data->stages[shader_id].shader);
+      assert(nir_stage);
 
       insert_rt_case(b, nir_stage, vars, sbt_idx, 0, 
data->groups[i].handle.any_hit_index,
                      shader_id, data->groups);
@@ -1363,13 +1365,15 @@ handle_candidate_aabb(nir_builder *b, struct radv_leaf_intersection *intersectio
       if (is_dup)
          continue;
 
-      const VkPipelineShaderStageCreateInfo *stage = 
&data->createInfo->pStages[shader_id];
-      nir_shader *nir_stage = radv_parse_rt_stage(data->device, stage, 
data->key);
+      nir_shader *nir_stage =
+         radv_pipeline_cache_handle_to_nir(data->device, 
data->stages[shader_id].shader);
+      assert(nir_stage);
 
       nir_shader *any_hit_stage = NULL;
       if (any_hit_shader_id != VK_SHADER_UNUSED_KHR) {
-         stage = &data->createInfo->pStages[any_hit_shader_id];
-         any_hit_stage = radv_parse_rt_stage(data->device, stage, data->key);
+         any_hit_stage =
+            radv_pipeline_cache_handle_to_nir(data->device, 
data->stages[any_hit_shader_id].shader);
+         assert(any_hit_stage);
 
          nir_lower_intersection_shader(nir_stage, any_hit_stage);
          ralloc_free(any_hit_stage);
@@ -1421,7 +1425,7 @@ load_stack_entry(nir_builder *b, nir_ssa_def *index, const struct radv_ray_trave
 }
 
 static nir_shader *
-build_traversal_shader(struct radv_device *device,
+build_traversal_shader(struct radv_device *device, struct 
radv_ray_tracing_stage *stages,
                        const VkRayTracingPipelineCreateInfoKHR *pCreateInfo,
                        struct radv_ray_tracing_group *groups, const struct 
radv_pipeline_key *key)
 {
@@ -1515,6 +1519,7 @@ build_traversal_shader(struct radv_device *device,
       .trav_vars = &trav_vars,
       .barycentrics = barycentrics,
       .groups = groups,
+      .stages = stages,
       .key = key,
    };
 
@@ -1619,7 +1624,8 @@ move_rt_instructions(nir_shader *shader)
 
 nir_shader *
 create_rt_shader(struct radv_device *device, const 
VkRayTracingPipelineCreateInfoKHR *pCreateInfo,
-                 struct radv_ray_tracing_group *groups, const struct 
radv_pipeline_key *key)
+                 struct radv_ray_tracing_stage *stages, struct 
radv_ray_tracing_group *groups,
+                 const struct radv_pipeline_key *key)
 {
    nir_builder b = radv_meta_init_shader(device, MESA_SHADER_RAYGEN, 
"rt_combined");
    b.shader->info.internal = false;
@@ -1635,7 +1641,7 @@ create_rt_shader(struct radv_device *device, const VkRayTracingPipelineCreateInf
    nir_ssa_def *idx = nir_load_var(&b, vars.idx);
 
    /* Insert traversal shader */
-   nir_shader *traversal = build_traversal_shader(device, pCreateInfo, groups, 
key);
+   nir_shader *traversal = build_traversal_shader(device, stages, pCreateInfo, 
groups, key);
    b.shader->info.shared_size = MAX2(b.shader->info.shared_size, 
traversal->info.shared_size);
    assert(b.shader->info.shared_size <= 32768);
    insert_rt_case(&b, traversal, &vars, idx, 0, 1, -1u, groups);
@@ -1657,13 +1663,12 @@ create_rt_shader(struct radv_device *device, const VkRayTracingPipelineCreateInf
       if (is_dup)
          continue;
 
-      const VkPipelineShaderStageCreateInfo *stage = 
&pCreateInfo->pStages[stage_idx];
-      ASSERTED gl_shader_stage type = vk_to_mesa_shader_stage(stage->stage);
+      nir_shader *nir_stage = radv_pipeline_cache_handle_to_nir(device, 
stages[stage_idx].shader);
+      assert(nir_stage);
+      ASSERTED gl_shader_stage type = nir_stage->info.stage;
       assert(type == MESA_SHADER_RAYGEN || type == MESA_SHADER_CALLABLE ||
              type == MESA_SHADER_CLOSEST_HIT || type == MESA_SHADER_MISS);
 
-      nir_shader *nir_stage = radv_parse_rt_stage(device, stage, key);
-
       /* Move ray tracing system values to the top that are set by rt_trace_ray
        * to prevent them from being overwritten by other rt_trace_ray calls.
        */
diff --git a/src/amd/vulkan/radv_shader.h b/src/amd/vulkan/radv_shader.h
index 4559c74b068..96fc1422819 100644
--- a/src/amd/vulkan/radv_shader.h
+++ b/src/amd/vulkan/radv_shader.h
@@ -43,6 +43,7 @@
 struct radv_physical_device;
 struct radv_device;
 struct radv_pipeline;
+struct radv_ray_tracing_stage;
 struct radv_ray_tracing_group;
 struct radv_pipeline_key;
 struct radv_shader_args;
@@ -765,6 +766,7 @@ void radv_get_nir_options(struct radv_physical_device *device);
 
 nir_shader *create_rt_shader(struct radv_device *device,
                              const VkRayTracingPipelineCreateInfoKHR 
*pCreateInfo,
+                             struct radv_ray_tracing_stage *stages,
                              struct radv_ray_tracing_group *groups,
                              const struct radv_pipeline_key *key);
 

Reply via email to