Module: Mesa
Branch: main
Commit: e6d26cb288033ec3fcebb032ad41e487cce03a7c
URL:    http://cgit.freedesktop.org/mesa/mesa/commit/?id=e6d26cb288033ec3fcebb032ad41e487cce03a7c

Author: Rhys Perry <[email protected]>
Date:   Tue Oct 18 20:52:53 2022 +0100

nir,ac/nir,aco,radv: replace has_input_*_amd with more general intrinsics

Signed-off-by: Rhys Perry <[email protected]>
Reviewed-by: Qiang Yu <[email protected]>
Part-of: <https://gitlab.freedesktop.org/mesa/mesa/-/merge_requests/19228>

---

 src/amd/common/ac_nir_lower_ngg.c                  | 41 ++++++++++++++--------
 src/amd/compiler/aco_instruction_selection.cpp     | 15 +++-----
 .../compiler/aco_instruction_selection_setup.cpp   |  2 --
 src/amd/llvm/ac_nir_to_llvm.c                      | 14 +++-----
 src/amd/vulkan/radv_nir_lower_abi.c                |  3 ++
 src/compiler/nir/nir_divergence_analysis.c         |  4 +--
 src/compiler/nir/nir_intrinsics.py                 | 11 +++---
 7 files changed, 47 insertions(+), 43 deletions(-)

diff --git a/src/amd/common/ac_nir_lower_ngg.c 
b/src/amd/common/ac_nir_lower_ngg.c
index 7782de9f3b2..424715b6498 100644
--- a/src/amd/common/ac_nir_lower_ngg.c
+++ b/src/amd/common/ac_nir_lower_ngg.c
@@ -441,12 +441,24 @@ emit_ngg_nogs_prim_exp_arg(nir_builder *b, 
lower_ngg_nogs_state *st)
    }
 }
 
+static nir_ssa_def *
+has_input_vertex(nir_builder *b)
+{
+   return nir_is_subgroup_invocation_lt_amd(b, 
nir_load_merged_wave_info_amd(b));
+}
+
+static nir_ssa_def *
+has_input_primitive(nir_builder *b)
+{
+   return nir_is_subgroup_invocation_lt_amd(b,
+                                            nir_ushr_imm(b, 
nir_load_merged_wave_info_amd(b), 8));
+}
+
 static void
 emit_ngg_nogs_prim_export(nir_builder *b, lower_ngg_nogs_state *st, 
nir_ssa_def *arg)
 {
-   nir_ssa_def *gs_thread = st->gs_accepted_var
-                            ? nir_load_var(b, st->gs_accepted_var)
-                            : nir_has_input_primitive_amd(b);
+   nir_ssa_def *gs_thread =
+      st->gs_accepted_var ? nir_load_var(b, st->gs_accepted_var) : 
has_input_primitive(b);
 
    nir_if *if_gs_thread = nir_push_if(b, gs_thread);
    {
@@ -506,8 +518,8 @@ emit_ngg_nogs_prim_export(nir_builder *b, 
lower_ngg_nogs_state *st, nir_ssa_def
 static void
 emit_ngg_nogs_prim_id_store_shared(nir_builder *b, lower_ngg_nogs_state *st)
 {
-   nir_ssa_def *gs_thread = st->gs_accepted_var ?
-      nir_load_var(b, st->gs_accepted_var) : nir_has_input_primitive_amd(b);
+   nir_ssa_def *gs_thread =
+      st->gs_accepted_var ? nir_load_var(b, st->gs_accepted_var) : 
has_input_primitive(b);
 
    nir_if *if_gs_thread = nir_push_if(b, gs_thread);
    {
@@ -986,8 +998,8 @@ compact_vertices_after_culling(nir_builder *b,
    nir_pop_if(b, if_gs_accepted);
 
    nir_store_var(b, es_accepted_var, es_survived, 0x1u);
-   nir_store_var(b, gs_accepted_var,
-                 nir_iand(b, nir_inot(b, fully_culled), 
nir_has_input_primitive_amd(b)), 0x1u);
+   nir_store_var(b, gs_accepted_var, nir_iand(b, nir_inot(b, fully_culled), 
has_input_primitive(b)),
+                 0x1u);
 }
 
 static void
@@ -1359,7 +1371,7 @@ add_deferred_attribute_culling(nir_builder *b, 
nir_cf_list *original_extracted_c
 
    b->cursor = nir_before_cf_list(&impl->body);
 
-   nir_ssa_def *es_thread = nir_has_input_vertex_amd(b);
+   nir_ssa_def *es_thread = has_input_vertex(b);
    nir_if *if_es_thread = nir_push_if(b, es_thread);
    {
       /* Initialize the position output variable to zeroes, in case not all 
VS/TES invocations store the output.
@@ -1392,7 +1404,8 @@ add_deferred_attribute_culling(nir_builder *b, 
nir_cf_list *original_extracted_c
    nir_pop_if(b, if_es_thread);
 
    nir_store_var(b, es_accepted_var, es_thread, 0x1u);
-   nir_store_var(b, gs_accepted_var, nir_has_input_primitive_amd(b), 0x1u);
+   nir_ssa_def *gs_thread = has_input_primitive(b);
+   nir_store_var(b, gs_accepted_var, gs_thread, 0x1u);
 
    /* Remove all non-position outputs, and put the position output into the 
variable. */
    nir_metadata_preserve(impl, nir_metadata_none);
@@ -1414,7 +1427,7 @@ add_deferred_attribute_culling(nir_builder *b, 
nir_cf_list *original_extracted_c
       nir_ssa_def *es_vertex_lds_addr = pervertex_lds_addr(b, 
invocation_index, pervertex_lds_bytes);
 
       /* ES invocations store their vertex data to LDS for GS threads to read. 
*/
-      if_es_thread = nir_push_if(b, nir_has_input_vertex_amd(b));
+      if_es_thread = nir_push_if(b, es_thread);
       if_es_thread->control = nir_selection_control_divergent_always_taken;
       {
          /* Store position components that are relevant to culling in LDS */
@@ -1440,7 +1453,7 @@ add_deferred_attribute_culling(nir_builder *b, 
nir_cf_list *original_extracted_c
       nir_store_var(b, prim_exp_arg_var, nir_imm_int(b, 1u << 31), 0x1u);
 
       /* GS invocations load the vertex data and perform the culling. */
-      nir_if *if_gs_thread = nir_push_if(b, nir_has_input_primitive_amd(b));
+      nir_if *if_gs_thread = nir_push_if(b, gs_thread);
       {
          /* Load vertex indices from input VGPRs */
          nir_ssa_def *vtx_idx[3] = {0};
@@ -1492,7 +1505,7 @@ add_deferred_attribute_culling(nir_builder *b, 
nir_cf_list *original_extracted_c
       nir_store_var(b, es_accepted_var, nir_imm_bool(b, false), 0x1u);
 
       /* ES invocations load their accepted flag from LDS. */
-      if_es_thread = nir_push_if(b, nir_has_input_vertex_amd(b));
+      if_es_thread = nir_push_if(b, es_thread);
       if_es_thread->control = nir_selection_control_divergent_always_taken;
       {
          nir_ssa_def *accepted = nir_load_shared(b, 1, 8u, es_vertex_lds_addr, 
.base = lds_es_vertex_accepted, .align_mul = 4u);
@@ -2021,7 +2034,7 @@ ac_nir_lower_ngg_nogs(nir_shader *shader, const 
ac_nir_lower_ngg_options *option
 
    nir_intrinsic_instr *export_vertex_instr;
    nir_ssa_def *es_thread =
-      options->can_cull ? nir_load_var(b, es_accepted_var) : 
nir_has_input_vertex_amd(b);
+      options->can_cull ? nir_load_var(b, es_accepted_var) : 
has_input_vertex(b);
 
    nir_if *if_es_thread = nir_push_if(b, es_thread);
    {
@@ -2972,7 +2985,7 @@ ac_nir_lower_ngg_gs(nir_shader *shader, const 
ac_nir_lower_ngg_options *options)
    state.lds_addr_gs_scratch = nir_load_lds_ngg_scratch_base_amd(b);
 
    /* Wrap the GS control flow. */
-   nir_if *if_gs_thread = nir_push_if(b, nir_has_input_primitive_amd(b));
+   nir_if *if_gs_thread = nir_push_if(b, has_input_primitive(b));
 
    nir_cf_reinsert(&extracted, b->cursor);
    b->cursor = nir_after_cf_list(&if_gs_thread->then_list);
diff --git a/src/amd/compiler/aco_instruction_selection.cpp 
b/src/amd/compiler/aco_instruction_selection.cpp
index 5948aa0640d..91d8c5c69d0 100644
--- a/src/amd/compiler/aco_instruction_selection.cpp
+++ b/src/amd/compiler/aco_instruction_selection.cpp
@@ -8250,6 +8250,7 @@ emit_interp_center(isel_context* ctx, Temp dst, Temp 
bary, Temp pos1, Temp pos2)
 }
 
 Temp merged_wave_info_to_mask(isel_context* ctx, unsigned i);
+Temp lanecount_to_mask(isel_context* ctx, Temp count);
 void ngg_emit_sendmsg_gs_alloc_req(isel_context* ctx, Temp vtx_cnt, Temp 
prm_cnt);
 static void create_primitive_exports(isel_context *ctx, Temp prim_ch1);
 static void create_vs_exports(isel_context* ctx);
@@ -9140,11 +9141,9 @@ visit_intrinsic(isel_context* ctx, nir_intrinsic_instr* 
instr)
       /* unused in the legacy pipeline, the HW keeps track of this for us */
       break;
    }
-   case nir_intrinsic_has_input_vertex_amd:
-   case nir_intrinsic_has_input_primitive_amd: {
-      assert(ctx->stage.hw == HWStage::NGG);
-      unsigned i = instr->intrinsic == nir_intrinsic_has_input_vertex_amd ? 0 
: 1;
-      bld.copy(Definition(get_ssa_temp(ctx, &instr->dest.ssa)), 
merged_wave_info_to_mask(ctx, i));
+   case nir_intrinsic_is_subgroup_invocation_lt_amd: {
+      Temp src = bld.as_uniform(get_ssa_temp(ctx, instr->src[0].ssa));
+      bld.copy(Definition(get_ssa_temp(ctx, &instr->dest.ssa)), 
lanecount_to_mask(ctx, src));
       break;
    }
    case nir_intrinsic_export_vertex_amd: {
@@ -11777,7 +11776,7 @@ cleanup_cfg(Program* program)
 }
 
 Temp
-lanecount_to_mask(isel_context* ctx, Temp count, bool allow64 = true)
+lanecount_to_mask(isel_context* ctx, Temp count)
 {
    assert(count.regClass() == s1);
 
@@ -11786,10 +11785,6 @@ lanecount_to_mask(isel_context* ctx, Temp count, bool 
allow64 = true)
    Temp cond;
 
    if (ctx->program->wave_size == 64) {
-      /* If we know that all 64 threads can't be active at a time, we just use 
the mask as-is */
-      if (!allow64)
-         return mask;
-
       /* Special case for 64 active invocations, because 64 doesn't work with 
s_bfm */
       Temp active_64 = bld.sopc(aco_opcode::s_bitcmp1_b32, bld.def(s1, scc), 
count,
                                 Operand::c32(6u /* log2(64) */));
diff --git a/src/amd/compiler/aco_instruction_selection_setup.cpp 
b/src/amd/compiler/aco_instruction_selection_setup.cpp
index eda0d7f7e1a..61064622988 100644
--- a/src/amd/compiler/aco_instruction_selection_setup.cpp
+++ b/src/amd/compiler/aco_instruction_selection_setup.cpp
@@ -597,8 +597,6 @@ init_context(isel_context* ctx, nir_shader* shader)
                case nir_intrinsic_first_invocation:
                case nir_intrinsic_ballot:
                case nir_intrinsic_bindless_image_samples:
-               case nir_intrinsic_has_input_vertex_amd:
-               case nir_intrinsic_has_input_primitive_amd:
                case nir_intrinsic_load_force_vrs_rates_amd:
                case nir_intrinsic_load_scalar_arg_amd:
                case nir_intrinsic_load_smem_amd: type = RegType::sgpr; break;
diff --git a/src/amd/llvm/ac_nir_to_llvm.c b/src/amd/llvm/ac_nir_to_llvm.c
index 05a1aee69a3..6ad93886bbd 100644
--- a/src/amd/llvm/ac_nir_to_llvm.c
+++ b/src/amd/llvm/ac_nir_to_llvm.c
@@ -4283,16 +4283,10 @@ static bool visit_intrinsic(struct ac_nir_context *ctx, 
nir_intrinsic_instr *ins
       else
          result = ctx->ac.i32_0;
       break;
-   case nir_intrinsic_has_input_vertex_amd: {
-      LLVMValueRef num =
-         ac_unpack_param(&ctx->ac, ac_get_arg(&ctx->ac, 
ctx->args->merged_wave_info), 0, 8);
-      result = LLVMBuildICmp(ctx->ac.builder, LLVMIntULT, 
ac_get_thread_id(&ctx->ac), num, "");
-      break;
-   }
-   case nir_intrinsic_has_input_primitive_amd: {
-      LLVMValueRef num =
-         ac_unpack_param(&ctx->ac, ac_get_arg(&ctx->ac, 
ctx->args->merged_wave_info), 8, 8);
-      result = LLVMBuildICmp(ctx->ac.builder, LLVMIntULT, 
ac_get_thread_id(&ctx->ac), num, "");
+   case nir_intrinsic_is_subgroup_invocation_lt_amd: {
+      LLVMValueRef count = LLVMBuildAnd(ctx->ac.builder, get_src(ctx, 
instr->src[0]),
+                                        LLVMConstInt(ctx->ac.i32, 0xff, 0), 
"");
+      result = LLVMBuildICmp(ctx->ac.builder, LLVMIntULT, 
ac_get_thread_id(&ctx->ac), count, "");
       break;
    }
    case nir_intrinsic_load_workgroup_num_input_vertices_amd:
diff --git a/src/amd/vulkan/radv_nir_lower_abi.c 
b/src/amd/vulkan/radv_nir_lower_abi.c
index 478796706ea..5ae3e977e9e 100644
--- a/src/amd/vulkan/radv_nir_lower_abi.c
+++ b/src/amd/vulkan/radv_nir_lower_abi.c
@@ -208,6 +208,9 @@ lower_abi_instr(nir_builder *b, nir_instr *instr, void 
*state)
    case nir_intrinsic_load_prim_xfb_query_enabled_amd:
       replacement = ngg_query_bool_setting(b, radv_ngg_query_prim_xfb, s);
       break;
+   case nir_intrinsic_load_merged_wave_info_amd:
+      replacement = ac_nir_load_arg(b, &s->args->ac, 
s->args->ac.merged_wave_info);
+      break;
    case nir_intrinsic_load_cull_any_enabled_amd:
       replacement = nggc_bool_setting(
          b, radv_nggc_front_face | radv_nggc_back_face | 
radv_nggc_small_primitives, s);
diff --git a/src/compiler/nir/nir_divergence_analysis.c 
b/src/compiler/nir/nir_divergence_analysis.c
index b916e63b47a..df4f4e7ddd2 100644
--- a/src/compiler/nir/nir_divergence_analysis.c
+++ b/src/compiler/nir/nir_divergence_analysis.c
@@ -171,6 +171,7 @@ visit_intrinsic(nir_shader *shader, nir_intrinsic_instr 
*instr)
    case nir_intrinsic_load_pipeline_stat_query_enabled_amd:
    case nir_intrinsic_load_prim_gen_query_enabled_amd:
    case nir_intrinsic_load_prim_xfb_query_enabled_amd:
+   case nir_intrinsic_load_merged_wave_info_amd:
    case nir_intrinsic_load_cull_front_face_enabled_amd:
    case nir_intrinsic_load_cull_back_face_enabled_amd:
    case nir_intrinsic_load_cull_ccw_amd:
@@ -642,8 +643,7 @@ visit_intrinsic(nir_shader *shader, nir_intrinsic_instr 
*instr)
    case nir_intrinsic_load_tlb_color_v3d:
    case nir_intrinsic_load_tess_rel_patch_id_amd:
    case nir_intrinsic_load_gs_vertex_offset_amd:
-   case nir_intrinsic_has_input_vertex_amd:
-   case nir_intrinsic_has_input_primitive_amd:
+   case nir_intrinsic_is_subgroup_invocation_lt_amd:
    case nir_intrinsic_load_packed_passthrough_primitive_amd:
    case nir_intrinsic_load_initial_edgeflags_amd:
    case nir_intrinsic_gds_atomic_add_amd:
diff --git a/src/compiler/nir/nir_intrinsics.py 
b/src/compiler/nir/nir_intrinsics.py
index 8cea094b1a8..4bc16b5ed4d 100644
--- a/src/compiler/nir/nir_intrinsics.py
+++ b/src/compiler/nir/nir_intrinsics.py
@@ -1375,11 +1375,9 @@ system_value("streamout_offset_amd", 1, indices=[BASE])
 
 # AMD merged shader intrinsics
 
-# Whether the current invocation has an input vertex / primitive to process 
(also known as "ES thread" or "GS thread").
-# Not safe to reorder because it changes after 
overwrite_subgroup_num_vertices_and_primitives_amd.
-# Also, the generated code is more optimal if they are not CSE'd.
-intrinsic("has_input_vertex_amd", src_comp=[], dest_comp=1, bit_sizes=[1], 
indices=[])
-intrinsic("has_input_primitive_amd", src_comp=[], dest_comp=1, bit_sizes=[1], 
indices=[])
+# Whether the current invocation index in the subgroup is less than the 
source. The source must be
+# subgroup uniform and bits 0-7 must be less than or equal to the wave size.
+intrinsic("is_subgroup_invocation_lt_amd", src_comp=[1], dest_comp=1, 
bit_sizes=[1], flags=[CAN_ELIMINATE])
 
 # AMD NGG intrinsics
 
@@ -1395,6 +1393,9 @@ system_value("pipeline_stat_query_enabled_amd", 
dest_comp=1, bit_sizes=[1])
 system_value("prim_gen_query_enabled_amd", dest_comp=1, bit_sizes=[1])
 # Whether NGG should execute shader query for primitive streamouted.
 system_value("prim_xfb_query_enabled_amd", dest_comp=1, bit_sizes=[1])
+# Merged wave info. Bits 0-7 are the ES thread count, 8-15 are the GS thread 
count, 16-24 is the
+# GS Wave ID, 24-27 is the wave index in the workgroup, and 28-31 is the 
workgroup size in waves.
+system_value("merged_wave_info_amd", dest_comp=1)
 # Whether the shader should cull front facing triangles.
 intrinsic("load_cull_front_face_enabled_amd", dest_comp=1, bit_sizes=[1], 
flags=[CAN_ELIMINATE])
 # Whether the shader should cull back facing triangles.

Reply via email to