Module: Mesa
Branch: main
Commit: 75dbb404393a5ae99adb90a156fa5a084aa79c4d
URL:    http://cgit.freedesktop.org/mesa/mesa/commit/?id=75dbb404393a5ae99adb90a156fa5a084aa79c4d

Author: Timur Kristóf <[email protected]>
Date:   Thu Sep  9 10:33:50 2021 +0200

ac/nir: Remove byte permute from prefix sum of the repack sequence.

The byte-permute instruction v_perm_b32 is not exposed by older
LLVM releases (only available on LLVM 13 and later), therefore a new
sequence is needed which we can use with these LLVM versions too.

The prefix sum is replaced by two alternatives:

1. For GPUs that support v_dot, we shift 0x01 to the wanted byte
positions and then use v_dot to sum the results.

2. For older GPUs (Navi 10), we simply shift out the unwanted bytes
and use v_sad_u8 to produce the sum.

Signed-off-by: Timur Kristóf <[email protected]>
Acked-by: Marek Olšák <[email protected]>
Part-of: <https://gitlab.freedesktop.org/mesa/mesa/-/merge_requests/12786>

---

 src/amd/common/ac_nir_lower_ngg.c | 100 +++++++++++++++++++++++++-------------
 1 file changed, 65 insertions(+), 35 deletions(-)

diff --git a/src/amd/common/ac_nir_lower_ngg.c 
b/src/amd/common/ac_nir_lower_ngg.c
index 6e63e153fb8..a3d9416bb19 100644
--- a/src/amd/common/ac_nir_lower_ngg.c
+++ b/src/amd/common/ac_nir_lower_ngg.c
@@ -133,6 +133,70 @@ typedef struct {
    nir_ssa_def *repacked_invocation_index;
 } wg_repack_result;
 
+/**
+ * Computes a horizontal sum of 8-bit packed values loaded from LDS.
+ *
+ * Each lane N will sum packed bytes 0 to N-1.
+ * We only care about the results from up to wave_id+1 lanes.
+ * (Other lanes are not deactivated but their calculation is not used.)
+ */
+static nir_ssa_def *
+summarize_repack(nir_builder *b, nir_ssa_def *packed_counts, unsigned num_lds_dwords)
+{
+   /* We'll use shift to filter out the bytes not needed by the current lane.
+    *
+    * Need to shift by: num_lds_dwords * 4 - lane_id (in bytes).
+    * However, two shifts are needed because one can't go all the way,
+    * so the shift amount is half that (and in bits): num_lds_dwords * 16 - lane_id * 4.
+    * The lane_id * -4u product wraps, so the addition must not be flagged no-unsigned-wrap.
+    *
+    * When v_dot4_u32_u8 is available, we right-shift a series of 0x01 bytes.
+    * This will yield 0x01 at wanted byte positions and 0x00 at unwanted positions,
+    * therefore v_dot can get rid of the unneeded values.
+    * This sequence is preferable because it better hides the latency of the LDS.
+    *
+    * If the v_dot instruction can't be used, we left-shift the packed bytes.
+    * This shifts out the unneeded bytes and shifts in zeroes, then v_sad_u8 sums them.
+    */
+
+   nir_ssa_def *lane_id = nir_load_subgroup_invocation(b);
+   nir_ssa_def *shift = nir_iadd_imm(b, nir_imul_imm(b, lane_id, -4u), num_lds_dwords * 16);
+   bool use_dot = b->shader->options->has_dot_4x8;
+
+   if (num_lds_dwords == 1) {
+      nir_ssa_def *dot_op = !use_dot ? NULL : nir_ushr(b, nir_ushr(b, nir_imm_int(b, 0x01010101), shift), shift);
+
+      /* Broadcast the packed data we read from LDS (to the first 16 lanes, but we only care up to num_waves). */
+      nir_ssa_def *packed = nir_build_lane_permute_16_amd(b, packed_counts, nir_imm_int(b, 0), nir_imm_int(b, 0));
+
+      /* Horizontally add the packed bytes. */
+      if (use_dot) {
+         return nir_udot_4x8_uadd(b, packed, dot_op, nir_imm_int(b, 0));
+      } else {
+         nir_ssa_def *sad_op = nir_ishl(b, nir_ishl(b, packed, shift), shift);
+         return nir_sad_u8x4(b, sad_op, nir_imm_int(b, 0), nir_imm_int(b, 0));
+      }
+   } else if (num_lds_dwords == 2) {
+      nir_ssa_def *dot_op = !use_dot ? NULL : nir_ushr(b, nir_ushr(b, nir_imm_int64(b, 0x0101010101010101), shift), shift);
+
+      /* Broadcast the packed data we read from LDS (to the first 16 lanes, but we only care up to num_waves). */
+      nir_ssa_def *packed_dw0 = nir_build_lane_permute_16_amd(b, nir_unpack_64_2x32_split_x(b, packed_counts), nir_imm_int(b, 0), nir_imm_int(b, 0));
+      nir_ssa_def *packed_dw1 = nir_build_lane_permute_16_amd(b, nir_unpack_64_2x32_split_y(b, packed_counts), nir_imm_int(b, 0), nir_imm_int(b, 0));
+
+      /* Horizontally add the packed bytes. */
+      if (use_dot) {
+         nir_ssa_def *sum = nir_udot_4x8_uadd(b, packed_dw0, nir_unpack_64_2x32_split_x(b, dot_op), nir_imm_int(b, 0));
+         return nir_udot_4x8_uadd(b, packed_dw1, nir_unpack_64_2x32_split_y(b, dot_op), sum);
+      } else {
+         nir_ssa_def *sad_op = nir_ishl(b, nir_ishl(b, nir_pack_64_2x32_split(b, packed_dw0, packed_dw1), shift), shift);
+         nir_ssa_def *sum = nir_sad_u8x4(b, nir_unpack_64_2x32_split_x(b, sad_op), nir_imm_int(b, 0), nir_imm_int(b, 0));
+         return nir_sad_u8x4(b, nir_unpack_64_2x32_split_y(b, sad_op), nir_imm_int(b, 0), sum);
+      }
+   } else {
+      unreachable("Unimplemented NGG wave count");
+   }
+}
+
 /**
  * Repacks invocations in the current workgroup to eliminate gaps between them.
  *
@@ -208,41 +272,7 @@ repack_invocations_in_workgroup(nir_builder *b, 
nir_ssa_def *input_bool,
     */
 
    nir_ssa_def *num_waves = nir_build_load_num_subgroups(b);
-
-   /* sel = 0x01010101 * lane_id + 0x03020100 */
-   nir_ssa_def *lane_id = nir_load_subgroup_invocation(b);
-   nir_ssa_def *packed_id = nir_build_byte_permute_amd(b, nir_imm_int(b, 0), 
lane_id, nir_imm_int(b, 0));
-   nir_ssa_def *sel = nir_iadd_imm_nuw(b, packed_id, 0x03020100);
-   nir_ssa_def *sum = NULL;
-
-   if (num_lds_dwords == 1) {
-      /* Broadcast the packed data we read from LDS (to the first 16 lanes, 
but we only care up to num_waves). */
-      nir_ssa_def *packed_dw = nir_build_lane_permute_16_amd(b, packed_counts, 
nir_imm_int(b, 0), nir_imm_int(b, 0));
-
-      /* Use byte-permute to filter out the bytes not needed by the current 
lane. */
-      nir_ssa_def *filtered_packed = nir_build_byte_permute_amd(b, packed_dw, 
nir_imm_int(b, 0), sel);
-
-      /* Horizontally add the packed bytes. */
-      sum = nir_sad_u8x4(b, filtered_packed, nir_imm_int(b, 0), nir_imm_int(b, 
0));
-   } else if (num_lds_dwords == 2) {
-      /* Create selectors for the byte-permutes below. */
-      nir_ssa_def *dw0_selector = nir_build_lane_permute_16_amd(b, sel, 
nir_imm_int(b, 0x44443210), nir_imm_int(b, 0x4));
-      nir_ssa_def *dw1_selector = nir_build_lane_permute_16_amd(b, sel, 
nir_imm_int(b, 0x32100000), nir_imm_int(b, 0x4));
-
-      /* Broadcast the packed data we read from LDS (to the first 16 lanes, 
but we only care up to num_waves). */
-      nir_ssa_def *packed_dw0 = nir_build_lane_permute_16_amd(b, 
nir_unpack_64_2x32_split_x(b, packed_counts), nir_imm_int(b, 0), nir_imm_int(b, 
0));
-      nir_ssa_def *packed_dw1 = nir_build_lane_permute_16_amd(b, 
nir_unpack_64_2x32_split_y(b, packed_counts), nir_imm_int(b, 0), nir_imm_int(b, 
0));
-
-      /* Use byte-permute to filter out the bytes not needed by the current 
lane. */
-      nir_ssa_def *filtered_packed_dw0 = nir_build_byte_permute_amd(b, 
packed_dw0, nir_imm_int(b, 0), dw0_selector);
-      nir_ssa_def *filtered_packed_dw1 = nir_build_byte_permute_amd(b, 
packed_dw1, nir_imm_int(b, 0), dw1_selector);
-
-      /* Horizontally add the packed bytes. */
-      sum = nir_sad_u8x4(b, filtered_packed_dw0, nir_imm_int(b, 0), 
nir_imm_int(b, 0));
-      sum = nir_sad_u8x4(b, filtered_packed_dw1, nir_imm_int(b, 0), sum);
-   } else {
-      unreachable("Unimplemented NGG wave count");
-   }
+   nir_ssa_def *sum = summarize_repack(b, packed_counts, num_lds_dwords);
 
    nir_ssa_def *wg_repacked_index_base = nir_build_read_invocation(b, sum, 
wave_id);
    nir_ssa_def *wg_num_repacked_invocations = nir_build_read_invocation(b, 
sum, num_waves);

Reply via email to