Module: Mesa
Branch: main
Commit: d0f7587109d42eb76099d242952ee1418c4c48e9
URL:    http://cgit.freedesktop.org/mesa/mesa/commit/?id=d0f7587109d42eb76099d242952ee1418c4c48e9

Author: Bas Nieuwenhuizen <[email protected]>
Date:   Wed Jan 11 02:32:27 2023 +0100

radv: Use group handles based on shader hashes.

Should be stable.

Part-of: <https://gitlab.freedesktop.org/mesa/mesa/-/merge_requests/21406>

---

 src/amd/vulkan/radv_device.c      | 11 ++++-
 src/amd/vulkan/radv_pipeline_rt.c | 88 ++++++++++++++++++++++++++++++++-------
 src/amd/vulkan/radv_private.h     |  3 ++
 3 files changed, 87 insertions(+), 15 deletions(-)

diff --git a/src/amd/vulkan/radv_device.c b/src/amd/vulkan/radv_device.c
index c433d5068c0..1f4a12299c1 100644
--- a/src/amd/vulkan/radv_device.c
+++ b/src/amd/vulkan/radv_device.c
@@ -1815,7 +1815,7 @@ radv_GetPhysicalDeviceFeatures2(VkPhysicalDevice physicalDevice,
             (VkPhysicalDeviceRayTracingPipelineFeaturesKHR *)ext;
          features->rayTracingPipeline = true;
          features->rayTracingPipelineShaderGroupHandleCaptureReplay = true;
-         features->rayTracingPipelineShaderGroupHandleCaptureReplayMixed = true;
+         features->rayTracingPipelineShaderGroupHandleCaptureReplayMixed = false;
          features->rayTracingPipelineTraceRaysIndirect = true;
          features->rayTraversalPrimitiveCulling = true;
          break;
@@ -3914,6 +3914,9 @@ radv_CreateDevice(VkPhysicalDevice physicalDevice, const VkDeviceCreateInfo *pCr
    device->physical_device = physical_device;
    simple_mtx_init(&device->trace_mtx, mtx_plain);
    simple_mtx_init(&device->pstate_mtx, mtx_plain);
+   simple_mtx_init(&device->rt_handles_mtx, mtx_plain);
+
+   device->rt_handles = _mesa_hash_table_create(NULL, _mesa_hash_u32, _mesa_key_u32_equal);
 
    device->ws = physical_device->ws;
    vk_device_set_drm_fd(&device->vk, device->ws->get_fd(device->ws));
@@ -4251,8 +4254,11 @@ fail:
          device->ws->ctx_destroy(device->hw_ctx[i]);
    }
 
+   _mesa_hash_table_destroy(device->rt_handles, NULL);
+
    simple_mtx_destroy(&device->pstate_mtx);
    simple_mtx_destroy(&device->trace_mtx);
+   simple_mtx_destroy(&device->rt_handles_mtx);
    mtx_destroy(&device->overallocation_mutex);
 
    vk_device_finish(&device->vk);
@@ -4292,6 +4298,8 @@ radv_DestroyDevice(VkDevice _device, const VkAllocationCallbacks *pAllocator)
       vk_free(&device->vk.alloc, device->private_sdma_queue);
    }
 
+   _mesa_hash_table_destroy(device->rt_handles, NULL);
+
    for (unsigned i = 0; i < RADV_NUM_HW_CTX; i++) {
       if (device->hw_ctx[i])
          device->ws->ctx_destroy(device->hw_ctx[i]);
@@ -4300,6 +4308,7 @@ radv_DestroyDevice(VkDevice _device, const VkAllocationCallbacks *pAllocator)
    mtx_destroy(&device->overallocation_mutex);
    simple_mtx_destroy(&device->pstate_mtx);
    simple_mtx_destroy(&device->trace_mtx);
+   simple_mtx_destroy(&device->rt_handles_mtx);
 
    radv_device_finish_meta(device);
 
diff --git a/src/amd/vulkan/radv_pipeline_rt.c b/src/amd/vulkan/radv_pipeline_rt.c
index 197dcd9f6c6..19db7f69198 100644
--- a/src/amd/vulkan/radv_pipeline_rt.c
+++ b/src/amd/vulkan/radv_pipeline_rt.c
@@ -27,8 +27,61 @@
 #include "radv_private.h"
 #include "radv_shader.h"
 
+struct rt_handle_hash_entry {
+   uint32_t key;
+   char hash[20];
+};
+
+static uint32_t
+handle_from_stages(struct radv_device *device, const VkPipelineShaderStageCreateInfo *stages,
+                   unsigned stage_count, bool replay_namespace)
+{
+   struct mesa_sha1 ctx;
+   _mesa_sha1_init(&ctx);
+
+   radv_hash_rt_stages(&ctx, stages, stage_count);
+   unsigned char hash[20];
+   _mesa_sha1_final(&ctx, hash);
+
+   uint32_t ret;
+   memcpy(&ret, hash, sizeof(ret));
+
+   /* Leave the low half for resume shaders etc. */
+   ret |= 1u << 31;
+
+   /* Ensure we have dedicated space for replayable shaders */
+   ret &= ~(1u << 30);
+   ret |= replay_namespace << 30;
+
+   simple_mtx_lock(&device->rt_handles_mtx);
+
+   struct hash_entry *he = NULL;
+   for (;;) {
+      he = _mesa_hash_table_search(device->rt_handles, &ret);
+      if (!he)
+         break;
+
+      if (memcmp(he->data, hash, sizeof(hash)) == 0)
+         break;
+
+      ++ret;
+   }
+
+   if (!he) {
+      struct rt_handle_hash_entry *e = ralloc(device->rt_handles, struct rt_handle_hash_entry);
+      e->key = ret;
+      memcpy(e->hash, hash, sizeof(e->hash));
+      _mesa_hash_table_insert(device->rt_handles, &e->key, &e->hash);
+   }
+
+   simple_mtx_unlock(&device->rt_handles_mtx);
+
+   return ret;
+}
+
 static VkResult
-radv_create_group_handles(const VkRayTracingPipelineCreateInfoKHR *pCreateInfo,
+radv_create_group_handles(struct radv_device *device,
+                          const VkRayTracingPipelineCreateInfoKHR *pCreateInfo,
                           struct radv_pipeline_group_handle **out_handles)
 {
    struct radv_pipeline_group_handle *handles = calloc(sizeof(*handles), pCreateInfo->groupCount);
@@ -36,35 +89,42 @@ radv_create_group_handles(const VkRayTracingPipelineCreateInfoKHR *pCreateInfo,
       return VK_ERROR_OUT_OF_HOST_MEMORY;
    }
 
-   /* For General and ClosestHit shaders, we can use the shader ID directly as handle.
-    * As (potentially different) AnyHit shaders are inlined, for Intersection shaders
-    * we use the Group ID.
-    */
+   bool capture_replay = pCreateInfo->flags &
+                         VK_PIPELINE_CREATE_RAY_TRACING_SHADER_GROUP_HANDLE_CAPTURE_REPLAY_BIT_KHR;
    for (unsigned i = 0; i < pCreateInfo->groupCount; ++i) {
       const VkRayTracingShaderGroupCreateInfoKHR *group_info = &pCreateInfo->pGroups[i];
       switch (group_info->type) {
       case VK_RAY_TRACING_SHADER_GROUP_TYPE_GENERAL_KHR:
          if (group_info->generalShader != VK_SHADER_UNUSED_KHR)
-            handles[i].general_index = group_info->generalShader + 2;
+            handles[i].general_index = handle_from_stages(
+               device, &pCreateInfo->pStages[group_info->generalShader], 1, capture_replay);
          break;
       case VK_RAY_TRACING_SHADER_GROUP_TYPE_PROCEDURAL_HIT_GROUP_KHR:
          if (group_info->closestHitShader != VK_SHADER_UNUSED_KHR)
-            handles[i].closest_hit_index = group_info->closestHitShader + 2;
-         if (group_info->intersectionShader != VK_SHADER_UNUSED_KHR)
-            handles[i].intersection_index = i + 2;
+            handles[i].closest_hit_index = handle_from_stages(
+               device, &pCreateInfo->pStages[group_info->closestHitShader], 1, capture_replay);
+         if (group_info->intersectionShader != VK_SHADER_UNUSED_KHR) {
+            VkPipelineShaderStageCreateInfo stages[2];
+            unsigned cnt = 0;
+            stages[cnt++] = pCreateInfo->pStages[group_info->intersectionShader];
+            if (group_info->anyHitShader != VK_SHADER_UNUSED_KHR)
+               stages[cnt++] = pCreateInfo->pStages[group_info->anyHitShader];
+            handles[i].intersection_index = handle_from_stages(device, stages, cnt, capture_replay);
+         }
          break;
       case VK_RAY_TRACING_SHADER_GROUP_TYPE_TRIANGLES_HIT_GROUP_KHR:
          if (group_info->closestHitShader != VK_SHADER_UNUSED_KHR)
-            handles[i].closest_hit_index = group_info->closestHitShader + 2;
+            handles[i].closest_hit_index = handle_from_stages(
+               device, &pCreateInfo->pStages[group_info->closestHitShader], 1, capture_replay);
          if (group_info->anyHitShader != VK_SHADER_UNUSED_KHR)
-            handles[i].any_hit_index = i + 2;
+            handles[i].any_hit_index = handle_from_stages(
+               device, &pCreateInfo->pStages[group_info->anyHitShader], 1, capture_replay);
          break;
       case VK_SHADER_GROUP_SHADER_MAX_ENUM_KHR:
          unreachable("VK_SHADER_GROUP_SHADER_MAX_ENUM_KHR");
       }
 
-      if (pCreateInfo->flags &
-          VK_PIPELINE_CREATE_RAY_TRACING_SHADER_GROUP_HANDLE_CAPTURE_REPLAY_BIT_KHR) {
+      if (capture_replay) {
          if (group_info->pShaderGroupCaptureReplayHandle &&
              memcmp(group_info->pShaderGroupCaptureReplayHandle, &handles[i], sizeof(handles[i])) !=
                 0) {
@@ -403,7 +463,7 @@ radv_rt_pipeline_create(VkDevice _device, VkPipelineCache _cache,
    radv_pipeline_init(device, &rt_pipeline->base.base, RADV_PIPELINE_RAY_TRACING);
    rt_pipeline->group_count = local_create_info.groupCount;
 
-   result = radv_create_group_handles(&local_create_info, &rt_pipeline->group_handles);
+   result = radv_create_group_handles(device, &local_create_info, &rt_pipeline->group_handles);
    if (result != VK_SUCCESS)
       goto pipeline_fail;
 
diff --git a/src/amd/vulkan/radv_private.h b/src/amd/vulkan/radv_private.h
index 87fca8b5d51..f0816afe1d8 100644
--- a/src/amd/vulkan/radv_private.h
+++ b/src/amd/vulkan/radv_private.h
@@ -1040,6 +1040,9 @@ struct radv_device {
    bool uses_device_generated_commands;
 
    bool uses_shadow_regs;
+
+   struct hash_table *rt_handles;
+   simple_mtx_t rt_handles_mtx;
 };
 
 bool radv_device_set_pstate(struct radv_device *device, bool enable);

Reply via email to