Module: Mesa
Branch: main
Commit: 57dec0678e48a10df756f36f910decd1ec0552a7
URL:    http://cgit.freedesktop.org/mesa/mesa/commit/?id=57dec0678e48a10df756f36f910decd1ec0552a7

Author: Samuel Pitoiset <[email protected]>
Date:   Wed Sep 20 17:03:29 2023 +0200

ac/nir: add lowering for mesh shader queries

Signed-off-by: Samuel Pitoiset <[email protected]>
Part-of: <https://gitlab.freedesktop.org/mesa/mesa/-/merge_requests/25331>

---

 src/amd/common/ac_nir.h           |  3 ++-
 src/amd/common/ac_nir_lower_ngg.c | 50 ++++++++++++++++++++++++++++++++++++++-
 src/amd/vulkan/radv_shader.c      |  2 +-
 3 files changed, 52 insertions(+), 3 deletions(-)

diff --git a/src/amd/common/ac_nir.h b/src/amd/common/ac_nir.h
index c9072c863bd..7c560e3b563 100644
--- a/src/amd/common/ac_nir.h
+++ b/src/amd/common/ac_nir.h
@@ -195,7 +195,8 @@ ac_nir_lower_ngg_ms(nir_shader *shader,
                     bool has_param_exports,
                     bool *out_needs_scratch_ring,
                     unsigned wave_size,
-                    bool multiview);
+                    bool multiview,
+                    bool has_query);
 
 void
 ac_nir_lower_task_outputs_to_mem(nir_shader *shader,
diff --git a/src/amd/common/ac_nir_lower_ngg.c b/src/amd/common/ac_nir_lower_ngg.c
index e4a850a09ec..0872c1e2f15 100644
--- a/src/amd/common/ac_nir_lower_ngg.c
+++ b/src/amd/common/ac_nir_lower_ngg.c
@@ -227,6 +227,9 @@ typedef struct
    uint32_t clipdist_enable_mask;
    const uint8_t *vs_output_param_offset;
    bool has_param_exports;
+
+   /* True if the lowering needs to insert shader query. */
+   bool has_query;
 } lower_ngg_ms_state;
 
 /* Per-vertex LDS layout of culling shaders */
@@ -4401,6 +4404,45 @@ ms_prim_exp_arg_ch2(nir_builder *b, uint64_t outputs_mask, lower_ngg_ms_state *s
    return prim_exp_arg_ch2;
 }
 
+static void
+ms_prim_gen_query(nir_builder *b,
+                  nir_def *invocation_index,
+                  nir_def *num_prm,
+                  lower_ngg_ms_state *s)
+{
+   if (!s->has_query)
+      return;
+
+   nir_if *if_invocation_index_zero = nir_push_if(b, nir_ieq_imm(b, 
invocation_index, 0));
+   {
+      nir_if *if_shader_query = nir_push_if(b, 
nir_load_prim_gen_query_enabled_amd(b));
+      {
+         nir_atomic_add_gen_prim_count_amd(b, num_prm, .stream_id = 0);
+      }
+      nir_pop_if(b, if_shader_query);
+   }
+   nir_pop_if(b, if_invocation_index_zero);
+}
+
+static void
+ms_invocation_query(nir_builder *b,
+                    nir_def *invocation_index,
+                    lower_ngg_ms_state *s)
+{
+   if (!s->has_query)
+      return;
+
+   nir_if *if_invocation_index_zero = nir_push_if(b, nir_ieq_imm(b, 
invocation_index, 0));
+   {
+      nir_if *if_pipeline_query = nir_push_if(b, 
nir_load_pipeline_stat_query_enabled_amd(b));
+      {
+         nir_atomic_add_shader_invocation_count_amd(b, nir_imm_int(b, 
s->api_workgroup_size));
+      }
+      nir_pop_if(b, if_pipeline_query);
+   }
+   nir_pop_if(b, if_invocation_index_zero);
+}
+
 static void
 ms_emit_primitive_export(nir_builder *b,
                          nir_def *invocation_index,
@@ -4435,6 +4477,8 @@ emit_ms_finale(nir_builder *b, lower_ngg_ms_state *s)
 
    nir_def *invocation_index = nir_load_local_invocation_index(b);
 
+   ms_prim_gen_query(b, invocation_index, num_prm, s);
+
    /* Load vertex/primitive attributes from shared memory and
     * emit store_output intrinsics for them.
     *
@@ -4664,6 +4708,8 @@ handle_smaller_ms_api_workgroup(nir_builder *b,
                                .memory_semantics = NIR_MEMORY_ACQ_REL,
                                .memory_modes = nir_var_shader_out | 
nir_var_mem_shared);
       }
+
+      ms_invocation_query(b, invocation_index, s);
    }
    nir_pop_if(b, if_has_api_ms_invocation);
 
@@ -4832,7 +4878,8 @@ ac_nir_lower_ngg_ms(nir_shader *shader,
                     bool has_param_exports,
                     bool *out_needs_scratch_ring,
                     unsigned wave_size,
-                    bool multiview)
+                    bool multiview,
+                    bool has_query)
 {
    unsigned vertices_per_prim =
       num_mesh_vertices_per_primitive(shader->info.mesh.primitive_type);
@@ -4886,6 +4933,7 @@ ac_nir_lower_ngg_ms(nir_shader *shader,
       .clipdist_enable_mask = clipdist_enable_mask,
       .vs_output_param_offset = vs_output_param_offset,
       .has_param_exports = has_param_exports,
+      .has_query = has_query,
    };
 
    nir_function_impl *impl = nir_shader_get_entrypoint(shader);
diff --git a/src/amd/vulkan/radv_shader.c b/src/amd/vulkan/radv_shader.c
index eecfe43acf2..4d718652425 100644
--- a/src/amd/vulkan/radv_shader.c
+++ b/src/amd/vulkan/radv_shader.c
@@ -912,7 +912,7 @@ radv_lower_ngg(struct radv_device *device, struct radv_shader_stage *ngg_stage,
       bool scratch_ring = false;
       NIR_PASS_V(nir, ac_nir_lower_ngg_ms, options.gfx_level, 
options.clipdist_enable_mask,
                  options.vs_output_param_offset, options.has_param_exports, 
&scratch_ring, info->wave_size,
-                 pl_key->has_multiview_view_index);
+                 pl_key->has_multiview_view_index, false);
       ngg_stage->info.ms.needs_ms_scratch_ring = scratch_ring;
    } else {
       unreachable("invalid SW stage passed to radv_lower_ngg");

Reply via email to