Module: Mesa
Branch: main
Commit: 074f3216f207cbe7b56b57cb909b2c108c8a1e7a
URL:    http://cgit.freedesktop.org/mesa/mesa/commit/?id=074f3216f207cbe7b56b57cb909b2c108c8a1e7a

Author: Qiang Yu <[email protected]>
Date:   Thu Jun 30 16:10:53 2022 +0800

ac/nir/ngg: support gs streamout

Port from radeonsi.

Reviewed-by: Timur Kristóf <[email protected]>
Signed-off-by: Qiang Yu <[email protected]>
Part-of: <https://gitlab.freedesktop.org/mesa/mesa/-/merge_requests/17654>

---
 src/amd/common/ac_nir.h           |   3 +-
 src/amd/common/ac_nir_lower_ngg.c | 126 ++++++++++++++++++++++++++++++++++++--
 src/amd/vulkan/radv_shader.c      |   3 +-
 3 files changed, 125 insertions(+), 7 deletions(-)

diff --git a/src/amd/common/ac_nir.h b/src/amd/common/ac_nir.h
index bfddb2a6ba3..ac4351f324e 100644
--- a/src/amd/common/ac_nir.h
+++ b/src/amd/common/ac_nir.h
@@ -144,7 +144,8 @@ ac_nir_lower_ngg_gs(nir_shader *shader,
                     unsigned gs_out_vtx_bytes,
                     unsigned gs_total_out_vtx_bytes,
                     bool provoking_vtx_last,
-                    bool can_cull);
+                    bool can_cull,
+                    bool disable_streamout);
 
 void
 ac_nir_lower_ngg_ms(nir_shader *shader,
diff --git a/src/amd/common/ac_nir_lower_ngg.c b/src/amd/common/ac_nir_lower_ngg.c
index c7af6517a53..1965ef96e6c 100644
--- a/src/amd/common/ac_nir_lower_ngg.c
+++ b/src/amd/common/ac_nir_lower_ngg.c
@@ -112,6 +112,7 @@ typedef struct
    bool output_compile_time_known;
    bool provoking_vertex_last;
    bool can_cull;
+   bool streamout_enabled;
 
    gs_output_info output_info[VARYING_SLOT_MAX];
 } lower_ngg_gs_state;
@@ -2572,6 +2573,110 @@ ngg_gs_cull_primitive(nir_builder *b, nir_ssa_def *tid_in_tg, nir_ssa_def *max_v
    return nir_load_var(b, primflag_var);
 }
 
+static void
+ngg_gs_build_streamout(nir_builder *b, lower_ngg_gs_state *st)
+{
+   nir_xfb_info *info = nir_gather_xfb_info_from_intrinsics(b->shader, NULL);
+   if (unlikely(!info))
+      return;
+
+   nir_ssa_def *tid_in_tg = nir_load_local_invocation_index(b);
+   nir_ssa_def *max_vtxcnt = nir_load_workgroup_num_input_vertices_amd(b);
+   nir_ssa_def *out_vtx_lds_addr = ngg_gs_out_vertex_addr(b, tid_in_tg, st);
+   nir_ssa_def *prim_live[4] = {0};
+   nir_ssa_def *gen_prim[4] = {0};
+   nir_ssa_def *export_seq[4] = {0};
+   nir_ssa_def *out_vtx_primflag[4] = {0};
+   for (unsigned stream = 0; stream < 4; stream++) {
+      if (!(info->streams_written & BITFIELD_BIT(stream)))
+         continue;
+
+      out_vtx_primflag[stream] =
+         ngg_gs_load_out_vtx_primflag(b, stream, tid_in_tg, out_vtx_lds_addr, max_vtxcnt, st);
+
+      /* Check bit 0 of primflag for primitive liveness; it's set for every last
+       * vertex of a primitive.
+       */
+      prim_live[stream] = nir_i2b(b, nir_iand_imm(b, out_vtx_primflag[stream], 1));
+
+      unsigned scratch_stride = ALIGN(st->max_num_waves, 4);
+
+      /* We want to export primitives to the streamout buffer in sequence,
+       * but not all vertices are alive or mark the end of a primitive, so
+       * there are "holes". Unlike the final vertex export, we don't need
+       * contiguous invocations to write primitives to the streamout buffer,
+       * so repacking just to get the sequence (export_seq) is enough; no
+       * compaction is needed.
+       *
+       * Use separate scratch space for each stream to avoid a barrier.
+       * TODO: we may further reduce barriers by writing to all streams'
+       * LDS at once, then we only need one barrier instead of one per
+       * stream.
+       */
+      wg_repack_result rep =
+         repack_invocations_in_workgroup(b, prim_live[stream],
+                                         st->lds_addr_gs_scratch + stream * scratch_stride,
+                                         st->max_num_waves, st->wave_size);
+
+      /* nir_intrinsic_set_vertex_and_primitive_count can also give the per-wave primitive
+       * count, but we'd still need LDS to sum all waves' counts into a workgroup count,
+       * and we need the repack to export primitives to the streamout buffer anyway, so do it here.
+       */
+      gen_prim[stream] = rep.num_repacked_invocations;
+      export_seq[stream] = rep.repacked_invocation_index;
+   }
+
+   /* Workgroup barrier: wait for LDS scratch reads to finish. */
+   nir_scoped_barrier(b, .execution_scope = NIR_SCOPE_WORKGROUP,
+                      .memory_scope = NIR_SCOPE_WORKGROUP,
+                      .memory_semantics = NIR_MEMORY_ACQ_REL,
+                      .memory_modes = nir_var_mem_shared);
+
+   /* Get the global buffer offset where this workgroup will stream out data to. */
+   nir_ssa_def *emit_prim[4] = {0};
+   nir_ssa_def *buffer_offsets[4] = {0};
+   nir_ssa_def *so_buffer[4] = {0};
+   nir_ssa_def *prim_stride[4] = {0};
+   ngg_build_streamout_buffer_info(b, info, st->lds_addr_gs_scratch, tid_in_tg, gen_prim,
+                                   prim_stride, so_buffer, buffer_offsets, emit_prim);
+
+   /* GS uses packed locations for vertex LDS storage. */
+   int slot_to_register[NUM_TOTAL_VARYING_SLOTS];
+   for (int i = 0; i < info->output_count; i++) {
+      unsigned location = info->outputs[i].location;
+      slot_to_register[location] =
+         util_bitcount64(b->shader->info.outputs_written & BITFIELD64_MASK(location));
+   }
+
+   for (unsigned stream = 0; stream < 4; stream++) {
+      if (!(info->streams_written & BITFIELD_BIT(stream)))
+         continue;
+
+      nir_ssa_def *can_emit = nir_ilt(b, export_seq[stream], emit_prim[stream]);
+      nir_if *if_emit = nir_push_if(b, nir_iand(b, can_emit, prim_live[stream]));
+      {
+         /* Get the streamout buffer vertex index for the first vertex of this primitive. */
+         nir_ssa_def *vtx_buffer_idx =
+            nir_imul_imm(b, export_seq[stream], st->num_vertices_per_primitive);
+
+         /* Get the LDS address of every vertex of this primitive. */
+         nir_ssa_def *exported_vtx_lds_addr[3];
+         ngg_gs_out_prim_all_vtxptr(b, tid_in_tg, out_vtx_lds_addr,
+                                    out_vtx_primflag[stream], st,
+                                    exported_vtx_lds_addr);
+
+         /* Write all vertices of this primitive to the streamout buffer. */
+         for (unsigned i = 0; i < st->num_vertices_per_primitive; i++) {
+            ngg_build_streamout_vertex(b, info, stream, slot_to_register,
+                                       so_buffer, buffer_offsets,
+                                       nir_iadd_imm(b, vtx_buffer_idx, i),
+                                       exported_vtx_lds_addr[i]);
+         }
+      }
+      nir_pop_if(b, if_emit);
+   }
+}
+
 static void
 ngg_gs_finale(nir_builder *b, lower_ngg_gs_state *s)
 {
@@ -2589,9 +2694,7 @@ ngg_gs_finale(nir_builder *b, lower_ngg_gs_state *s)
       nir_pop_if(b, if_wave_0);
    }
 
-   /* Workgroup barrier: wait for all GS threads to finish */
-   nir_scoped_barrier(b, .execution_scope=NIR_SCOPE_WORKGROUP, .memory_scope=NIR_SCOPE_WORKGROUP,
-                      .memory_semantics=NIR_MEMORY_ACQ_REL, .memory_modes=nir_var_mem_shared);
+   /* Workgroup barrier already emitted; we can assume all GS output stores are done by now. */
 
    nir_ssa_def *out_vtx_primflag_0 =
       ngg_gs_load_out_vtx_primflag(b, 0, tid_in_tg, out_vtx_lds_addr, max_vtxcnt, s);
@@ -2654,7 +2757,8 @@ ac_nir_lower_ngg_gs(nir_shader *shader,
                     unsigned gs_out_vtx_bytes,
                     unsigned gs_total_out_vtx_bytes,
                     bool provoking_vertex_last,
-                    bool can_cull)
+                    bool can_cull,
+                    bool disable_streamout)
 {
    nir_function_impl *impl = nir_shader_get_entrypoint(shader);
    assert(impl);
@@ -2669,9 +2773,14 @@ ac_nir_lower_ngg_gs(nir_shader *shader,
      .lds_bytes_per_gs_out_vertex = gs_out_vtx_bytes + 4u,
      .provoking_vertex_last = provoking_vertex_last,
      .can_cull = can_cull,
+     .streamout_enabled = shader->xfb_info && !disable_streamout,
   };
 
-   unsigned lds_scratch_bytes = DIV_ROUND_UP(state.max_num_waves, 4u) * 4u;
+   unsigned lds_scratch_bytes = ALIGN(state.max_num_waves, 4u);
+   /* Streamout takes 8 dwords for buffer offsets and emit vertex counts per stream. */
+   if (state.streamout_enabled)
+      lds_scratch_bytes = MAX2(lds_scratch_bytes, 32);
+
    unsigned total_lds_bytes = state.lds_addr_gs_scratch + lds_scratch_bytes;
    shader->info.shared_size = total_lds_bytes;
 
@@ -2715,6 +2824,13 @@ ac_nir_lower_ngg_gs(nir_shader *shader,
    b->cursor = nir_after_cf_list(&if_gs_thread->then_list);
    nir_pop_if(b, if_gs_thread);
 
+   /* Workgroup barrier: wait for all GS threads to finish */
+   nir_scoped_barrier(b, .execution_scope=NIR_SCOPE_WORKGROUP, .memory_scope=NIR_SCOPE_WORKGROUP,
+                      .memory_semantics=NIR_MEMORY_ACQ_REL, .memory_modes=nir_var_mem_shared);
+
+   if (state.streamout_enabled)
+      ngg_gs_build_streamout(b, &state);
+
    /* Lower the GS intrinsics */
    lower_ngg_gs_intrinsics(shader, &state);
    b->cursor = nir_after_cf_list(&impl->body);
diff --git a/src/amd/vulkan/radv_shader.c b/src/amd/vulkan/radv_shader.c
index e9145443bb4..5a7e49ef6f4 100644
--- a/src/amd/vulkan/radv_shader.c
+++ b/src/amd/vulkan/radv_shader.c
@@ -1341,7 +1341,8 @@ void radv_lower_ngg(struct radv_device *device, struct radv_pipeline_stage *ngg_
       assert(info->is_ngg);
       NIR_PASS_V(nir, ac_nir_lower_ngg_gs, info->wave_size, info->workgroup_size,
                  info->ngg_info.esgs_ring_size, info->gs.gsvs_vertex_size,
-                 info->ngg_info.ngg_emit_size * 4u, pl_key->vs.provoking_vtx_last, false);
+                 info->ngg_info.ngg_emit_size * 4u, pl_key->vs.provoking_vtx_last,
+                 false, true);
    } else if (nir->info.stage == MESA_SHADER_MESH) {
       bool scratch_ring = false;
       NIR_PASS_V(nir, ac_nir_lower_ngg_ms, &scratch_ring, info->wave_size, pl_key->has_multiview_view_index);
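For reference, the core of ngg_gs_build_streamout() above is the per-stream repack: every invocation whose primflag bit 0 is set owns one live primitive, and the repack turns those scattered flags into a dense export order (export_seq) plus a generated-primitive count (gen_prim) that feeds the buffer-offset bookkeeping. Below is a minimal scalar sketch of that indexing idea in standalone C; names like repack_scalar and repack_result are illustrative only, while the actual pass does this wave-wide in LDS via repack_invocations_in_workgroup().

#include <stdbool.h>

typedef struct {
   unsigned num_live;    /* corresponds to gen_prim[stream] */
   unsigned export_seq;  /* corresponds to export_seq[stream] for one invocation */
} repack_result;

/* Exclusive prefix sum over the per-invocation "primitive live" flags:
 * each live primitive gets a compact slot in the streamout export order,
 * and the total live count is what the workgroup adds to the global
 * streamout buffer offset. Hypothetical helper, for illustration only. */
static repack_result
repack_scalar(const bool *prim_live, unsigned num_threads, unsigned my_tid)
{
   repack_result r = {0, 0};
   for (unsigned t = 0; t < num_threads; t++) {
      if (t == my_tid)
         r.export_seq = r.num_live;
      if (prim_live[t])
         r.num_live++;
   }
   return r;
}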

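Likewise, the new LDS scratch sizing keeps the per-wave repack scratch (max_num_waves rounded up to a multiple of 4 bytes, same value the old DIV_ROUND_UP expression produced) but reserves at least 32 bytes when streamout is enabled, matching the patch's own comment that streamout needs 8 dwords (a buffer offset and an emit count per stream, for 4 streams). A standalone sketch of that computation, using stand-ins for Mesa's ALIGN/MAX2 macros and a hypothetical gs_lds_scratch_bytes() helper:

#include <stdbool.h>

#define ALIGN(x, a) (((x) + (a) - 1) / (a) * (a))  /* stand-in for Mesa's align macro */
#define MAX2(a, b)  ((a) > (b) ? (a) : (b))        /* stand-in for Mesa's MAX2 */

static unsigned
gs_lds_scratch_bytes(unsigned max_num_waves, bool streamout_enabled)
{
   /* Per-wave repack scratch, rounded up to a dword. */
   unsigned lds_scratch_bytes = ALIGN(max_num_waves, 4u);
   /* Streamout needs 8 dwords: buffer offset + emit count for each of 4 streams. */
   if (streamout_enabled)
      lds_scratch_bytes = MAX2(lds_scratch_bytes, 32);
   return lds_scratch_bytes;
}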