Signed-off-by: Samuel Pitoiset <samuel.pitoi...@gmail.com>
---
 src/amd/vulkan/radv_shader.c | 30 +++++++++++++++++-------------
 1 file changed, 17 insertions(+), 13 deletions(-)

diff --git a/src/amd/vulkan/radv_shader.c b/src/amd/vulkan/radv_shader.c
index 97fa80b348c..f0ab2d5e467 100644
--- a/src/amd/vulkan/radv_shader.c
+++ b/src/amd/vulkan/radv_shader.c
@@ -124,6 +124,17 @@ unsigned shader_io_get_unique_index(gl_varying_slot slot)
        unreachable("illegal slot in get unique index\n");
 }
 
+static uint8_t /* Return the physical device's wave size for the given shader stage. */
+radv_get_shader_wave_size(const struct radv_physical_device *pdevice,
+                         gl_shader_stage stage)
+{
+       if (stage == MESA_SHADER_COMPUTE)
+               return pdevice->cs_wave_size;
+       else if (stage == MESA_SHADER_FRAGMENT)
+               return pdevice->ps_wave_size;
+       return pdevice->ge_wave_size; /* every stage other than compute/fragment uses ge_wave_size */
+}
+
 VkResult radv_CreateShaderModule(
        VkDevice                                    _device,
        const VkShaderModuleCreateInfo*             pCreateInfo,
@@ -422,9 +433,13 @@ radv_shader_compile_to_nir(struct radv_device *device,
 
        nir_lower_global_vars_to_local(nir);
        nir_remove_dead_variables(nir, nir_var_function_temp);
+
+       uint8_t wave_size = radv_get_shader_wave_size(device->physical_device,
+                                                     nir->info.stage);
+
        nir_lower_subgroups(nir, &(struct nir_lower_subgroups_options) {
-                       .subgroup_size = 64,
-                       .ballot_bit_size = 64,
+                       .subgroup_size = wave_size,
+                       .ballot_bit_size = wave_size,
                        .lower_to_scalar = 1,
                        .lower_subgroup_masks = 1,
                        .lower_shuffle = 1,
@@ -667,17 +682,6 @@ radv_get_shader_binary_size(size_t code_size)
        return code_size + DEBUGGER_NUM_MARKERS * 4;
 }
 
-static uint8_t
-radv_get_shader_wave_size(const struct radv_physical_device *pdevice,
-                         gl_shader_stage stage)
-{
-       if (stage == MESA_SHADER_COMPUTE)
-               return pdevice->cs_wave_size;
-       else if (stage == MESA_SHADER_FRAGMENT)
-               return pdevice->ps_wave_size;
-       return pdevice->ge_wave_size;
-}
-
 static void radv_postprocess_config(const struct radv_physical_device *pdevice,
                                    const struct ac_shader_config *config_in,
                                    const struct radv_shader_variant_info *info,
-- 
2.22.0

_______________________________________________
mesa-dev mailing list
mesa-dev@lists.freedesktop.org
https://lists.freedesktop.org/mailman/listinfo/mesa-dev

Reply via email to