Module: Mesa
Branch: main
Commit: 1f878334c0369f0f5c7b64df9ccd31c646dcc182
URL:    http://cgit.freedesktop.org/mesa/mesa/commit/?id=1f878334c0369f0f5c7b64df9ccd31c646dcc182

Author: Samuel Pitoiset <[email protected]>
Date:   Fri Mar 24 16:12:16 2023 +0100

radv: add radv_bind_shader() helper

Signed-off-by: Samuel Pitoiset <[email protected]>
Part-of: <https://gitlab.freedesktop.org/mesa/mesa/-/merge_requests/22128>

---

 src/amd/vulkan/radv_cmd_buffer.c | 86 +++++++++++++++++++++++++---------------
 1 file changed, 54 insertions(+), 32 deletions(-)

diff --git a/src/amd/vulkan/radv_cmd_buffer.c b/src/amd/vulkan/radv_cmd_buffer.c
index 202e0b79444..c1c12c57156 100644
--- a/src/amd/vulkan/radv_cmd_buffer.c
+++ b/src/amd/vulkan/radv_cmd_buffer.c
@@ -6367,6 +6367,48 @@ radv_bind_task_shader(struct radv_cmd_buffer *cmd_buffer, const struct radv_shad
    cmd_buffer->task_rings_needed = true;
 }
 
+/* Bind a shader to the cmdbuffer state for the given stage; a NULL shader is ignored (nothing is unbound). */
+static void
+radv_bind_shader(struct radv_cmd_buffer *cmd_buffer, struct radv_shader 
*shader,
+                 gl_shader_stage stage)
+{
+   if (!shader) /* NULL: return without touching any cmdbuffer state */
+      return;
+
+   switch (stage) {
+   case MESA_SHADER_VERTEX:
+      radv_bind_vertex_shader(cmd_buffer, shader);
+      break;
+   case MESA_SHADER_TESS_CTRL:
+      radv_bind_tess_ctrl_shader(cmd_buffer, shader);
+      break;
+   case MESA_SHADER_TESS_EVAL:
+      radv_bind_tess_eval_shader(cmd_buffer, shader);
+      break;
+   case MESA_SHADER_GEOMETRY:
+      radv_bind_geometry_shader(cmd_buffer, shader);
+      break;
+   case MESA_SHADER_FRAGMENT:
+      radv_bind_fragment_shader(cmd_buffer, shader);
+      break;
+   case MESA_SHADER_MESH:
+      radv_bind_mesh_shader(cmd_buffer, shader);
+      break;
+   case MESA_SHADER_TASK:
+      radv_bind_task_shader(cmd_buffer, shader);
+      break;
+   case MESA_SHADER_COMPUTE:
+   case MESA_SHADER_RAYGEN:
+      /* no-op: these stages are accepted, but nothing is bound for them here */
+      break;
+   default:
+      unreachable("invalid shader stage");
+   }
+}
+
+#define RADV_GRAPHICS_STAGES \
+   (VK_SHADER_STAGE_ALL_GRAPHICS | VK_SHADER_STAGE_MESH_BIT_EXT | 
VK_SHADER_STAGE_TASK_BIT_EXT) /* all graphics stages, including mesh and task */
+
 VKAPI_ATTR void VKAPI_CALL
 radv_CmdBindPipeline(VkCommandBuffer commandBuffer, VkPipelineBindPoint 
pipelineBindPoint,
                      VkPipeline _pipeline)
@@ -6382,6 +6424,9 @@ radv_CmdBindPipeline(VkCommandBuffer commandBuffer, VkPipelineBindPoint pipeline
          return;
       radv_mark_descriptor_sets_dirty(cmd_buffer, pipelineBindPoint);
 
+      radv_bind_shader(cmd_buffer, 
compute_pipeline->base.shaders[MESA_SHADER_COMPUTE],
+                       MESA_SHADER_COMPUTE);
+
       cmd_buffer->state.compute_pipeline = compute_pipeline;
       cmd_buffer->push_constant_stages |= VK_SHADER_STAGE_COMPUTE_BIT;
       break;
@@ -6393,6 +6438,11 @@ radv_CmdBindPipeline(VkCommandBuffer commandBuffer, VkPipelineBindPoint pipeline
          return;
       radv_mark_descriptor_sets_dirty(cmd_buffer, pipelineBindPoint);
 
+      radv_bind_shader(cmd_buffer, 
rt_pipeline->base.base.shaders[MESA_SHADER_COMPUTE],
+                       MESA_SHADER_COMPUTE);
+      radv_bind_shader(cmd_buffer, 
rt_pipeline->base.base.shaders[MESA_SHADER_RAYGEN],
+                       MESA_SHADER_RAYGEN);
+
       cmd_buffer->state.rt_pipeline = rt_pipeline;
       cmd_buffer->push_constant_stages |= RADV_RT_STAGE_BITS;
 
@@ -6408,6 +6458,10 @@ radv_CmdBindPipeline(VkCommandBuffer commandBuffer, VkPipelineBindPoint pipeline
          return;
       radv_mark_descriptor_sets_dirty(cmd_buffer, pipelineBindPoint);
 
+      radv_foreach_stage(stage, RADV_GRAPHICS_STAGES) {
+         radv_bind_shader(cmd_buffer, graphics_pipeline->base.shaders[stage], 
stage);
+      }
+
       bool vtx_emit_count_changed =
          !cmd_buffer->state.graphics_pipeline ||
          cmd_buffer->state.graphics_pipeline->vtx_emit_num != 
graphics_pipeline->vtx_emit_num ||
@@ -6480,38 +6534,6 @@ radv_CmdBindPipeline(VkCommandBuffer commandBuffer, VkPipelineBindPoint pipeline
          MAX2(cmd_buffer->scratch_size_per_wave_needed, 
pipeline->scratch_bytes_per_wave);
       cmd_buffer->scratch_waves_wanted = 
MAX2(cmd_buffer->scratch_waves_wanted, pipeline->max_waves);
 
-      for (uint32_t s = 0; s < MESA_SHADER_COMPUTE; s++) {
-         const struct radv_shader *shader = graphics_pipeline->base.shaders[s];
-
-         if (!shader)
-            continue;
-
-         switch (s) {
-         case MESA_SHADER_VERTEX:
-            radv_bind_vertex_shader(cmd_buffer, shader);
-            break;
-         case MESA_SHADER_TESS_CTRL:
-            radv_bind_tess_ctrl_shader(cmd_buffer, shader);
-            break;
-         case MESA_SHADER_TESS_EVAL:
-            radv_bind_tess_eval_shader(cmd_buffer, shader);
-            break;
-         case MESA_SHADER_GEOMETRY:
-            radv_bind_geometry_shader(cmd_buffer, shader);
-            break;
-         case MESA_SHADER_FRAGMENT:
-            radv_bind_fragment_shader(cmd_buffer, shader);
-            break;
-         default:
-            unreachable("invalid graphics shader stage");
-         }
-      }
-
-      if (graphics_pipeline->base.shaders[MESA_SHADER_MESH])
-         radv_bind_mesh_shader(cmd_buffer, 
graphics_pipeline->base.shaders[MESA_SHADER_MESH]);
-      if (graphics_pipeline->base.shaders[MESA_SHADER_TASK])
-         radv_bind_task_shader(cmd_buffer, 
graphics_pipeline->base.shaders[MESA_SHADER_TASK]);
-
       radv_bind_multisample_state(cmd_buffer, &graphics_pipeline->ms);
       break;
    }

Reply via email to