Module: Mesa
Branch: main
Commit: 2fb1097bac62bf93efe7b62629d79f9aa6e75cff
URL:    http://cgit.freedesktop.org/mesa/mesa/commit/?id=2fb1097bac62bf93efe7b62629d79f9aa6e75cff

Author: Qiang Yu <[email protected]>
Date:   Wed Nov 30 15:22:29 2022 +0800

ac/nir/ngg: merge multi stream gs shader queries

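Before this commit, each stream emitted its own shader query block; now
the queries for all streams are merged into a single block. The
per-stream vertex and primitive counts are recorded while lowering
set_vertex_and_primitive_count. The query code is then emitted once, at
the end of the GS thread's control flow.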

Reviewed-by: Samuel Pitoiset <[email protected]>
Acked-by: Marek Olšák <[email protected]>
Signed-off-by: Qiang Yu <[email protected]>
Part-of: <https://gitlab.freedesktop.org/mesa/mesa/-/merge_requests/20074>

---

 src/amd/common/ac_nir_lower_ngg.c | 78 +++++++++++++++++++++++++++------------
 1 file changed, 55 insertions(+), 23 deletions(-)

diff --git a/src/amd/common/ac_nir_lower_ngg.c 
b/src/amd/common/ac_nir_lower_ngg.c
index 5e6bbe8f953..a3e26586bed 100644
--- a/src/amd/common/ac_nir_lower_ngg.c
+++ b/src/amd/common/ac_nir_lower_ngg.c
@@ -120,7 +120,6 @@ typedef struct
    nir_ssa_def *lds_addr_gs_scratch;
    unsigned lds_bytes_per_gs_out_vertex;
    unsigned lds_offs_primflags;
-   bool found_out_vtxcnt[4];
    bool output_compile_time_known;
    bool streamout_enabled;
    /* 32 bit outputs */
@@ -131,6 +130,9 @@ typedef struct
    nir_variable *output_vars_16bit_lo[16][4];
    gs_output_info output_info_16bit_hi[16];
    gs_output_info output_info_16bit_lo[16];
+   /* Count per stream. */
+   nir_ssa_def *vertex_count[4];
+   nir_ssa_def *primitive_count[4];
 } lower_ngg_gs_state;
 
 /* LDS layout of Mesh Shader workgroup info. */
@@ -2390,7 +2392,7 @@ ngg_gs_clear_primflags(nir_builder *b, nir_ssa_def 
*num_vertices, unsigned strea
 }
 
 static void
-ngg_gs_shader_query(nir_builder *b, nir_intrinsic_instr *intrin, 
lower_ngg_gs_state *s)
+ngg_gs_shader_query(nir_builder *b, lower_ngg_gs_state *s)
 {
    bool has_gen_prim_query = s->options->has_gen_prim_query;
    bool has_pipeline_stats_query = s->options->gfx_level < GFX11;
@@ -2415,25 +2417,36 @@ ngg_gs_shader_query(nir_builder *b, nir_intrinsic_instr 
*intrin, lower_ngg_gs_st
    }
 
    nir_if *if_shader_query = nir_push_if(b, shader_query_enabled);
-   nir_ssa_def *num_prims_in_wave = NULL;
+
+   nir_ssa_def *active_threads_mask = nir_ballot(b, 1, s->options->wave_size, 
nir_imm_bool(b, true));
+   nir_ssa_def *num_active_threads = nir_bit_count(b, active_threads_mask);
 
    /* Calculate the "real" number of emitted primitives from the emitted GS 
vertices and primitives.
     * GS emits points, line strips or triangle strips.
     * Real primitives are points, lines or triangles.
     */
-   if (nir_src_is_const(intrin->src[0]) && nir_src_is_const(intrin->src[1])) {
-      unsigned gs_vtx_cnt = nir_src_as_uint(intrin->src[0]);
-      unsigned gs_prm_cnt = nir_src_as_uint(intrin->src[1]);
-      unsigned total_prm_cnt = gs_vtx_cnt - gs_prm_cnt * 
(s->num_vertices_per_primitive - 1u);
-      nir_ssa_def *num_threads =
-         nir_bit_count(b, nir_ballot(b, 1, s->options->wave_size, 
nir_imm_bool(b, true)));
-      num_prims_in_wave = nir_imul_imm(b, num_threads, total_prm_cnt);
-   } else {
-      nir_ssa_def *gs_vtx_cnt = intrin->src[0].ssa;
-      nir_ssa_def *prm_cnt = intrin->src[1].ssa;
-      if (s->num_vertices_per_primitive > 1)
-         prm_cnt = nir_iadd(b, nir_imul_imm(b, prm_cnt, -1u * 
(s->num_vertices_per_primitive - 1)), gs_vtx_cnt);
-      num_prims_in_wave = nir_reduce(b, prm_cnt, .reduction_op = nir_op_iadd);
+   nir_ssa_def *num_prims_in_wave[4] = {0};
+   u_foreach_bit (i, b->shader->info.gs.active_stream_mask) {
+      assert(s->vertex_count[i] && s->primitive_count[i]);
+
+      nir_ssa_scalar vtx_cnt = nir_get_ssa_scalar(s->vertex_count[i], 0);
+      nir_ssa_scalar prm_cnt = nir_get_ssa_scalar(s->primitive_count[i], 0);
+
+      if (nir_ssa_scalar_is_const(vtx_cnt) && 
nir_ssa_scalar_is_const(prm_cnt)) {
+         unsigned gs_vtx_cnt = nir_ssa_scalar_as_uint(vtx_cnt);
+         unsigned gs_prm_cnt = nir_ssa_scalar_as_uint(prm_cnt);
+         unsigned total_prm_cnt = gs_vtx_cnt - gs_prm_cnt * 
(s->num_vertices_per_primitive - 1u);
+         if (total_prm_cnt == 0)
+            continue;
+
+         num_prims_in_wave[i] = nir_imul_imm(b, num_active_threads, 
total_prm_cnt);
+      } else {
+         nir_ssa_def *gs_vtx_cnt = vtx_cnt.def;
+         nir_ssa_def *gs_prm_cnt = prm_cnt.def;
+         if (s->num_vertices_per_primitive > 1)
+            gs_prm_cnt = nir_iadd(b, nir_imul_imm(b, gs_prm_cnt, -1u * 
(s->num_vertices_per_primitive - 1)), gs_vtx_cnt);
+         num_prims_in_wave[i] = nir_reduce(b, gs_prm_cnt, .reduction_op = 
nir_op_iadd);
+      }
    }
 
    /* Store the query result to query result using an atomic add. */
@@ -2442,8 +2455,20 @@ ngg_gs_shader_query(nir_builder *b, nir_intrinsic_instr 
*intrin, lower_ngg_gs_st
       if (has_pipeline_stats_query) {
          nir_if *if_pipeline_query = nir_push_if(b, pipeline_query_enabled);
          {
+            nir_ssa_def *count = NULL;
+
             /* Add all streams' number to the same counter. */
-            nir_atomic_add_gs_emit_prim_count_amd(b, num_prims_in_wave);
+            for (int i = 0; i < 4; i++) {
+               if (num_prims_in_wave[i]) {
+                  if (count)
+                     count = nir_iadd(b, count, num_prims_in_wave[i]);
+                  else
+                     count = num_prims_in_wave[i];
+               }
+            }
+
+            if (count)
+               nir_atomic_add_gs_emit_prim_count_amd(b, count);
          }
          nir_pop_if(b, if_pipeline_query);
       }
@@ -2452,8 +2477,10 @@ ngg_gs_shader_query(nir_builder *b, nir_intrinsic_instr 
*intrin, lower_ngg_gs_st
          nir_if *if_prim_gen_query = nir_push_if(b, prim_gen_query_enabled);
          {
             /* Add to the counter for this stream. */
-            nir_atomic_add_gen_prim_count_amd(
-               b, num_prims_in_wave, .stream_id = 
nir_intrinsic_stream_id(intrin));
+            for (int i = 0; i < 4; i++) {
+               if (num_prims_in_wave[i])
+                  nir_atomic_add_gen_prim_count_amd(b, num_prims_in_wave[i], 
.stream_id = i);
+            }
          }
          nir_pop_if(b, if_prim_gen_query);
       }
@@ -2708,13 +2735,13 @@ lower_ngg_gs_set_vertex_and_primitive_count(nir_builder 
*b, nir_intrinsic_instr
       return true;
    }
 
-   s->found_out_vtxcnt[stream] = true;
+   s->vertex_count[stream] = intrin->src[0].ssa;
+   s->primitive_count[stream] = intrin->src[1].ssa;
 
    /* Clear the primitive flags of non-emitted vertices */
    if (!nir_src_is_const(intrin->src[0]) || nir_src_as_uint(intrin->src[0]) < 
b->shader->info.gs.vertices_out)
       ngg_gs_clear_primflags(b, intrin->src[0].ssa, stream, s);
 
-   ngg_gs_shader_query(b, intrin, s);
    nir_instr_remove(&intrin->instr);
    return true;
 }
@@ -3344,13 +3371,18 @@ ac_nir_lower_ngg_gs(nir_shader *shader, const 
ac_nir_lower_ngg_options *options)
 
    /* Lower the GS intrinsics */
    lower_ngg_gs_intrinsics(shader, &state);
-   b->cursor = nir_after_cf_list(&impl->body);
 
-   if (!state.found_out_vtxcnt[0]) {
+   if (!state.vertex_count[0]) {
       fprintf(stderr, "Could not find set_vertex_and_primitive_count for 
stream 0. This would hang your GPU.");
       abort();
    }
 
+   /* Emit shader queries */
+   b->cursor = nir_after_cf_list(&if_gs_thread->then_list);
+   ngg_gs_shader_query(b, &state);
+
+   b->cursor = nir_after_cf_list(&impl->body);
+
    /* Emit the finale sequence */
    ngg_gs_finale(b, &state);
    nir_validate_shader(shader, "after emitting NGG GS");

Reply via email to