Module: Mesa
Branch: main
Commit: 074f3216f207cbe7b56b57cb909b2c108c8a1e7a
URL:    http://cgit.freedesktop.org/mesa/mesa/commit/?id=074f3216f207cbe7b56b57cb909b2c108c8a1e7a

Author: Qiang Yu <[email protected]>
Date:   Thu Jun 30 16:10:53 2022 +0800

ac/nir/ngg: support gs streamout

Port from radeonsi.

Reviewed-by: Timur Kristóf <[email protected]>
Signed-off-by: Qiang Yu <[email protected]>
Part-of: <https://gitlab.freedesktop.org/mesa/mesa/-/merge_requests/17654>

---

 src/amd/common/ac_nir.h           |   3 +-
 src/amd/common/ac_nir_lower_ngg.c | 126 ++++++++++++++++++++++++++++++++++++--
 src/amd/vulkan/radv_shader.c      |   3 +-
 3 files changed, 125 insertions(+), 7 deletions(-)
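
The pass added below leans on repack_invocations_in_workgroup() to turn the scattered
set of invocations that finish a live primitive into a dense per-stream export sequence.
As a rough scalar sketch of that idea (not part of the patch; the function name and
array-based interface here are hypothetical, not the NIR pass itself):

#include <stdbool.h>

/* Scalar analogue of the per-stream repack: an exclusive prefix sum over the
 * "live" flags gives each invocation that ends a live primitive a dense export
 * index, and the total becomes the workgroup's generated-primitive count.
 */
static void
repack_live_primitives_scalar(const bool *prim_live, unsigned num_invocations,
                              unsigned *export_seq, unsigned *gen_prim)
{
   unsigned count = 0;

   for (unsigned i = 0; i < num_invocations; i++) {
      /* Dense position among live primitives (export_seq in the pass). */
      export_seq[i] = count;
      if (prim_live[i])
         count++;
   }

   /* Primitives this workgroup will write to the streamout buffer
    * (gen_prim in the pass, before buffer-space clamping).
    */
   *gen_prim = count;
}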

diff --git a/src/amd/common/ac_nir.h b/src/amd/common/ac_nir.h
index bfddb2a6ba3..ac4351f324e 100644
--- a/src/amd/common/ac_nir.h
+++ b/src/amd/common/ac_nir.h
@@ -144,7 +144,8 @@ ac_nir_lower_ngg_gs(nir_shader *shader,
                     unsigned gs_out_vtx_bytes,
                     unsigned gs_total_out_vtx_bytes,
                     bool provoking_vtx_last,
-                    bool can_cull);
+                    bool can_cull,
+                    bool disable_streamout);
 
 void
 ac_nir_lower_ngg_ms(nir_shader *shader,
diff --git a/src/amd/common/ac_nir_lower_ngg.c b/src/amd/common/ac_nir_lower_ngg.c
index c7af6517a53..1965ef96e6c 100644
--- a/src/amd/common/ac_nir_lower_ngg.c
+++ b/src/amd/common/ac_nir_lower_ngg.c
@@ -112,6 +112,7 @@ typedef struct
    bool output_compile_time_known;
    bool provoking_vertex_last;
    bool can_cull;
+   bool streamout_enabled;
    gs_output_info output_info[VARYING_SLOT_MAX];
 } lower_ngg_gs_state;
 
@@ -2572,6 +2573,110 @@ ngg_gs_cull_primitive(nir_builder *b, nir_ssa_def *tid_in_tg, nir_ssa_def *max_v
    return nir_load_var(b, primflag_var);
 }
 
+static void
+ngg_gs_build_streamout(nir_builder *b, lower_ngg_gs_state *st)
+{
+   nir_xfb_info *info = nir_gather_xfb_info_from_intrinsics(b->shader, NULL);
+   if (unlikely(!info))
+      return;
+
+   nir_ssa_def *tid_in_tg = nir_load_local_invocation_index(b);
+   nir_ssa_def *max_vtxcnt = nir_load_workgroup_num_input_vertices_amd(b);
+   nir_ssa_def *out_vtx_lds_addr = ngg_gs_out_vertex_addr(b, tid_in_tg, st);
+   nir_ssa_def *prim_live[4] = {0};
+   nir_ssa_def *gen_prim[4] = {0};
+   nir_ssa_def *export_seq[4] = {0};
+   nir_ssa_def *out_vtx_primflag[4] = {0};
+   for (unsigned stream = 0; stream < 4; stream++) {
+      if (!(info->streams_written & BITFIELD_BIT(stream)))
+         continue;
+
+      out_vtx_primflag[stream] =
+         ngg_gs_load_out_vtx_primflag(b, stream, tid_in_tg, out_vtx_lds_addr, max_vtxcnt, st);
+
+      /* Check bit 0 of the primflag for primitive liveness; it is set on the
+       * last vertex of each primitive.
+       */
+      prim_live[stream] = nir_i2b(b, nir_iand_imm(b, out_vtx_primflag[stream], 1));
+
+      unsigned scratch_stride = ALIGN(st->max_num_waves, 4);
+
+      /* We want to export primitives to the streamout buffer in sequence,
+       * but not every invocation is alive or marks the end of a primitive,
+       * so there are "holes". Unlike the final vertex export, we don't need
+       * contiguous invocations to write primitives to the streamout buffer,
+       * so repacking to get the sequence (export_seq) is enough; no
+       * compaction is needed.
+       *
+       * Use a separate scratch space for each stream to avoid a barrier.
+       * TODO: we may further reduce barriers by writing to all streams'
+       * LDS at once, then we'd only need one barrier instead of one per
+       * stream.
+       */
+      wg_repack_result rep =
+         repack_invocations_in_workgroup(b, prim_live[stream],
+                                         st->lds_addr_gs_scratch + stream * scratch_stride,
+                                         st->max_num_waves, st->wave_size);
+
+      /* nir_intrinsic_set_vertex_and_primitive_count could also provide the
+       * primitive count of the current wave, but we would still need LDS to
+       * sum all waves' counts into a workgroup count. We need the repack to
+       * export primitives to the streamout buffer anyway, so do it here.
+       */
+      gen_prim[stream] = rep.num_repacked_invocations;
+      export_seq[stream] = rep.repacked_invocation_index;
+   }
+
+   /* Workgroup barrier: wait for LDS scratch reads to finish. */
+   nir_scoped_barrier(b, .execution_scope = NIR_SCOPE_WORKGROUP,
+                      .memory_scope = NIR_SCOPE_WORKGROUP,
+                      .memory_semantics = NIR_MEMORY_ACQ_REL,
+                      .memory_modes = nir_var_mem_shared);
+
+   /* Get global buffer offset where this workgroup will stream out data to. */
+   nir_ssa_def *emit_prim[4] = {0};
+   nir_ssa_def *buffer_offsets[4] = {0};
+   nir_ssa_def *so_buffer[4] = {0};
+   nir_ssa_def *prim_stride[4] = {0};
+   ngg_build_streamout_buffer_info(b, info, st->lds_addr_gs_scratch, tid_in_tg, gen_prim,
+                                   prim_stride, so_buffer, buffer_offsets, emit_prim);
+
+   /* GS uses packed locations for vertex LDS storage. */
+   int slot_to_register[NUM_TOTAL_VARYING_SLOTS];
+   for (int i = 0; i < info->output_count; i++) {
+      unsigned location = info->outputs[i].location;
+      slot_to_register[location] =
+         util_bitcount64(b->shader->info.outputs_written & BITFIELD64_MASK(location));
+   }
+
+   for (unsigned stream = 0; stream < 4; stream++) {
+      if (!(info->streams_written & BITFIELD_BIT(stream)))
+         continue;
+
+      nir_ssa_def *can_emit = nir_ilt(b, export_seq[stream], emit_prim[stream]);
+      nir_if *if_emit = nir_push_if(b, nir_iand(b, can_emit, prim_live[stream]));
+      {
+         /* Get streamout buffer vertex index for the first vertex of this primitive. */
+         nir_ssa_def *vtx_buffer_idx =
+            nir_imul_imm(b, export_seq[stream], st->num_vertices_per_primitive);
+
+         /* Get all vertices' lds address of this primitive. */
+         nir_ssa_def *exported_vtx_lds_addr[3];
+         ngg_gs_out_prim_all_vtxptr(b, tid_in_tg, out_vtx_lds_addr,
+                                    out_vtx_primflag[stream], st,
+                                    exported_vtx_lds_addr);
+
+         /* Write all vertices of this primitive to streamout buffer. */
+         for (unsigned i = 0; i < st->num_vertices_per_primitive; i++) {
+            ngg_build_streamout_vertex(b, info, stream, slot_to_register,
+                                       so_buffer, buffer_offsets,
+                                       nir_iadd_imm(b, vtx_buffer_idx, i),
+                                       exported_vtx_lds_addr[i]);
+         }
+      }
+      nir_pop_if(b, if_emit);
+   }
+}
+
 static void
 ngg_gs_finale(nir_builder *b, lower_ngg_gs_state *s)
 {
@@ -2589,9 +2694,7 @@ ngg_gs_finale(nir_builder *b, lower_ngg_gs_state *s)
       nir_pop_if(b, if_wave_0);
    }
 
-   /* Workgroup barrier: wait for all GS threads to finish */
-   nir_scoped_barrier(b, .execution_scope=NIR_SCOPE_WORKGROUP, .memory_scope=NIR_SCOPE_WORKGROUP,
-                         .memory_semantics=NIR_MEMORY_ACQ_REL, .memory_modes=nir_var_mem_shared);
+   /* Workgroup barrier already emitted, we can assume all GS output stores are done by now. */
 
    nir_ssa_def *out_vtx_primflag_0 = ngg_gs_load_out_vtx_primflag(b, 0, tid_in_tg, out_vtx_lds_addr, max_vtxcnt, s);
 
@@ -2654,7 +2757,8 @@ ac_nir_lower_ngg_gs(nir_shader *shader,
                     unsigned gs_out_vtx_bytes,
                     unsigned gs_total_out_vtx_bytes,
                     bool provoking_vertex_last,
-                    bool can_cull)
+                    bool can_cull,
+                    bool disable_streamout)
 {
    nir_function_impl *impl = nir_shader_get_entrypoint(shader);
    assert(impl);
@@ -2669,9 +2773,14 @@ ac_nir_lower_ngg_gs(nir_shader *shader,
       .lds_bytes_per_gs_out_vertex = gs_out_vtx_bytes + 4u,
       .provoking_vertex_last = provoking_vertex_last,
       .can_cull = can_cull,
+      .streamout_enabled = shader->xfb_info && !disable_streamout,
    };
 
-   unsigned lds_scratch_bytes = DIV_ROUND_UP(state.max_num_waves, 4u) * 4u;
+   unsigned lds_scratch_bytes = ALIGN(state.max_num_waves, 4u);
+   /* Streamout takes 8 dwords for the buffer offset and emitted-vertex count per stream. */
+   if (state.streamout_enabled)
+      lds_scratch_bytes = MAX2(lds_scratch_bytes, 32);
+
    unsigned total_lds_bytes = state.lds_addr_gs_scratch + lds_scratch_bytes;
    shader->info.shared_size = total_lds_bytes;
 
@@ -2715,6 +2824,13 @@ ac_nir_lower_ngg_gs(nir_shader *shader,
    b->cursor = nir_after_cf_list(&if_gs_thread->then_list);
    nir_pop_if(b, if_gs_thread);
 
+   /* Workgroup barrier: wait for all GS threads to finish */
+   nir_scoped_barrier(b, .execution_scope=NIR_SCOPE_WORKGROUP, .memory_scope=NIR_SCOPE_WORKGROUP,
+                         .memory_semantics=NIR_MEMORY_ACQ_REL, .memory_modes=nir_var_mem_shared);
+
+   if (state.streamout_enabled)
+      ngg_gs_build_streamout(b, &state);
+
    /* Lower the GS intrinsics */
    lower_ngg_gs_intrinsics(shader, &state);
    b->cursor = nir_after_cf_list(&impl->body);
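
For reference, the LDS scratch sizing in ac_nir_lower_ngg_gs() above boils down to the
following (a minimal sketch, not part of the patch, assuming the ALIGN and MAX2 helpers
from Mesa's util/u_math.h; the standalone function is hypothetical): without streamout
the scratch only needs the invocation-repack space, while with streamout it must also
hold at least 8 dwords (32 bytes) for the per-stream buffer offsets and emitted-vertex
counts.

#include <stdbool.h>
#include "util/u_math.h"   /* ALIGN, MAX2 */

static unsigned
gs_lds_scratch_bytes(unsigned max_num_waves, bool streamout_enabled)
{
   /* Repack scratch: max_num_waves bytes, padded to a dword boundary. */
   unsigned scratch_bytes = ALIGN(max_num_waves, 4u);

   /* With streamout, reserve at least 8 dwords (32 bytes) for the
    * per-stream buffer offsets and emitted-vertex counts.
    */
   if (streamout_enabled)
      scratch_bytes = MAX2(scratch_bytes, 32u);

   return scratch_bytes;
}
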
diff --git a/src/amd/vulkan/radv_shader.c b/src/amd/vulkan/radv_shader.c
index e9145443bb4..5a7e49ef6f4 100644
--- a/src/amd/vulkan/radv_shader.c
+++ b/src/amd/vulkan/radv_shader.c
@@ -1341,7 +1341,8 @@ void radv_lower_ngg(struct radv_device *device, struct radv_pipeline_stage *ngg_
       assert(info->is_ngg);
       NIR_PASS_V(nir, ac_nir_lower_ngg_gs, info->wave_size, info->workgroup_size,
                  info->ngg_info.esgs_ring_size, info->gs.gsvs_vertex_size,
-                 info->ngg_info.ngg_emit_size * 4u, pl_key->vs.provoking_vtx_last, false);
+                 info->ngg_info.ngg_emit_size * 4u, pl_key->vs.provoking_vtx_last,
+                 false, true);
    } else if (nir->info.stage == MESA_SHADER_MESH) {
       bool scratch_ring = false;
       NIR_PASS_V(nir, ac_nir_lower_ngg_ms, &scratch_ring, info->wave_size, pl_key->has_multiview_view_index);
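
Finally, the per-invocation export test that ngg_gs_build_streamout() builds in NIR can
be read as the following scalar predicate (a sketch over plain integers, not part of the
patch; the helper name is hypothetical):

#include <stdbool.h>
#include <stdint.h>

static bool
should_export_primitive(uint32_t primflag, uint32_t export_seq, uint32_t emit_prim)
{
   /* Bit 0 of the primflag is set on the last vertex of each live primitive. */
   bool prim_live = (primflag & 1u) != 0;

   /* export_seq is this invocation's repacked index among live primitives;
    * only indices below the primitive count the streamout buffer can still
    * take (emit_prim) get written.
    */
   return prim_live && export_seq < emit_prim;
}

Only invocations whose repacked index fits under emit_prim write their primitive, which
is what lets the pass skip full vertex compaction for streamout.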
