Module: Mesa
Branch: main
Commit: cddbe9a4c238a1ac4202b7d099b5225e9aef9d87
URL:    http://cgit.freedesktop.org/mesa/mesa/commit/?id=cddbe9a4c238a1ac4202b7d099b5225e9aef9d87

Author: Rhys Perry <[email protected]>
Date:   Fri Sep  1 11:24:56 2023 +0100

ac/nir: implement mesh shader gs_fast_launch=2

Signed-off-by: Rhys Perry <[email protected]>
Reviewed-by: Timur Kristóf <[email protected]>
Part-of: <https://gitlab.freedesktop.org/mesa/mesa/-/merge_requests/25040>

---

 src/amd/common/ac_nir.h           |  3 ++-
 src/amd/common/ac_nir_lower_ngg.c | 20 ++++++++++++++------
 src/amd/vulkan/radv_shader.c      |  2 +-
 3 files changed, 17 insertions(+), 8 deletions(-)

diff --git a/src/amd/common/ac_nir.h b/src/amd/common/ac_nir.h
index dc0651a5fc3..0d91a7edb2c 100644
--- a/src/amd/common/ac_nir.h
+++ b/src/amd/common/ac_nir.h
@@ -197,7 +197,8 @@ ac_nir_lower_ngg_ms(nir_shader *shader,
                     bool *out_needs_scratch_ring,
                     unsigned wave_size,
                     bool multiview,
-                    bool has_query);
+                    bool has_query,
+                    bool fast_launch_2);
 
 void
 ac_nir_lower_task_outputs_to_mem(nir_shader *shader,
diff --git a/src/amd/common/ac_nir_lower_ngg.c b/src/amd/common/ac_nir_lower_ngg.c
index e624a5c7a20..2257415c845 100644
--- a/src/amd/common/ac_nir_lower_ngg.c
+++ b/src/amd/common/ac_nir_lower_ngg.c
@@ -197,6 +197,7 @@ typedef struct
 typedef struct
 {
    enum amd_gfx_level gfx_level;
+   bool fast_launch_2;
 
    ms_out_mem_layout layout;
    uint64_t per_vertex_outputs;
@@ -4513,6 +4514,10 @@ emit_ms_finale(nir_builder *b, lower_ngg_ms_state *s)
 
    ms_prim_gen_query(b, invocation_index, num_prm, s);
 
+   nir_def *row_start = NULL;
+   if (s->fast_launch_2)
+      row_start = s->hw_workgroup_size <= s->wave_size ? nir_imm_int(b, 0) : 
nir_load_subgroup_id(b);
+
    /* Load vertex/primitive attributes from shared memory and
     * emit store_output intrinsics for them.
     *
@@ -4544,7 +4549,7 @@ emit_ms_finale(nir_builder *b, lower_ngg_ms_state *s)
       nir_def *has_output_vertex = nir_ilt(b, invocation_index, num_vtx);
       nir_if *if_has_output_vertex = nir_push_if(b, has_output_vertex);
       {
-         emit_ms_vertex(b, invocation_index, NULL, !wait_attr_ring, true, 
per_vertex_outputs, s);
+         emit_ms_vertex(b, invocation_index, row_start, !wait_attr_ring, true, 
per_vertex_outputs, s);
       }
       nir_pop_if(b, if_has_output_vertex);
    }
@@ -4554,7 +4559,7 @@ emit_ms_finale(nir_builder *b, lower_ngg_ms_state *s)
       nir_def *has_output_primitive = nir_ilt(b, invocation_index, num_prm);
       nir_if *if_has_output_primitive = nir_push_if(b, has_output_primitive);
       {
-         emit_ms_primitive(b, invocation_index, NULL, !wait_attr_ring, true, 
per_primitive_outputs, s);
+         emit_ms_primitive(b, invocation_index, row_start, !wait_attr_ring, 
true, per_primitive_outputs, s);
       }
       nir_pop_if(b, if_has_output_primitive);
    }
@@ -4574,14 +4579,14 @@ emit_ms_finale(nir_builder *b, lower_ngg_ms_state *s)
       nir_def *has_output_vertex = nir_ilt(b, invocation_index, num_vtx);
       nir_if *if_has_output_vertex = nir_push_if(b, has_output_vertex);
       {
-         emit_ms_vertex(b, invocation_index, NULL, true, false, 
per_vertex_outputs, s);
+         emit_ms_vertex(b, invocation_index, row_start, true, false, 
per_vertex_outputs, s);
       }
       nir_pop_if(b, if_has_output_vertex);
 
       nir_def *has_output_primitive = nir_ilt(b, invocation_index, num_prm);
       nir_if *if_has_output_primitive = nir_push_if(b, has_output_primitive);
       {
-         emit_ms_primitive(b, invocation_index, NULL, true, false, 
per_primitive_outputs, s);
+         emit_ms_primitive(b, invocation_index, row_start, true, false, 
per_primitive_outputs, s);
       }
       nir_pop_if(b, if_has_output_primitive);
    }
@@ -4866,7 +4871,8 @@ ac_nir_lower_ngg_ms(nir_shader *shader,
                     bool *out_needs_scratch_ring,
                     unsigned wave_size,
                     bool multiview,
-                    bool has_query)
+                    bool has_query,
+                    bool fast_launch_2)
 {
    unsigned vertices_per_prim =
       num_mesh_vertices_per_primitive(shader->info.mesh.primitive_type);
@@ -4917,6 +4923,7 @@ ac_nir_lower_ngg_ms(nir_shader *shader,
       .insert_layer_output = multiview && !(shader->info.outputs_written & 
VARYING_BIT_LAYER),
       .uses_cull_flags = uses_cull,
       .gfx_level = gfx_level,
+      .fast_launch_2 = fast_launch_2,
       .clipdist_enable_mask = clipdist_enable_mask,
       .vs_output_param_offset = vs_output_param_offset,
       .has_param_exports = has_param_exports,
@@ -4935,7 +4942,8 @@ ac_nir_lower_ngg_ms(nir_shader *shader,
    nir_builder *b = &builder; /* This is to avoid the & */
 
    handle_smaller_ms_api_workgroup(b, &state);
-   ms_emit_legacy_workgroup_index(b, &state);
+   if (!fast_launch_2)
+      ms_emit_legacy_workgroup_index(b, &state);
    ms_create_same_invocation_vars(b, &state);
    nir_metadata_preserve(impl, nir_metadata_none);
 
diff --git a/src/amd/vulkan/radv_shader.c b/src/amd/vulkan/radv_shader.c
index 2903ad2ba2e..5aac387cac9 100644
--- a/src/amd/vulkan/radv_shader.c
+++ b/src/amd/vulkan/radv_shader.c
@@ -916,7 +916,7 @@ radv_lower_ngg(struct radv_device *device, struct radv_shader_stage *ngg_stage,
       bool scratch_ring = false;
       NIR_PASS_V(nir, ac_nir_lower_ngg_ms, options.gfx_level, 
options.clipdist_enable_mask,
                  options.vs_output_param_offset, options.has_param_exports, 
&scratch_ring, info->wave_size,
-                 pl_key->has_multiview_view_index, info->ms.has_query);
+                 pl_key->has_multiview_view_index, info->ms.has_query, false);
       ngg_stage->info.ms.needs_ms_scratch_ring = scratch_ring;
    } else {
       unreachable("invalid SW stage passed to radv_lower_ngg");

Reply via email to