Module: Mesa
Branch: main
Commit: cca40086c6a43db1ad281d9b1e5f92f10f26acca
URL:    http://cgit.freedesktop.org/mesa/mesa/commit/?id=cca40086c6a43db1ad281d9b1e5f92f10f26acca

Author: Faith Ekstrand <faith.ekstr...@collabora.com>
Date:   Fri Nov 17 13:30:08 2023 -0600

nak: Lower scan/reduce in NIR

We can probably do slightly better than this if we take advantage of the
predicate destination in SHFL, but not by much.  All of the insanity is
still required (NVIDIA's compiler basically emits this); we just might be
able to save ourselves a few comparison ops.
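
As an illustration only (not part of this patch), here is a rough scalar C
model of the full-subgroup inclusive iadd scan that build_scan_full()
produces; the per-step has_buddy comparison is the part a predicate
destination on SHFL could give us for free:

#include <stdbool.h>
#include <stdint.h>
#include <string.h>

/* Hypothetical stand-in for one 32-wide subgroup: lanes[i] holds lane i's
 * input on entry and its inclusive scan result on exit.
 */
static void
model_inclusive_iadd_scan(uint32_t lanes[32])
{
   for (unsigned i = 1; i < 32; i *= 2) {
      uint32_t prev[32];
      memcpy(prev, lanes, sizeof(prev));

      for (unsigned lane = 0; lane < 32; lane++) {
         bool has_buddy = lane >= i;                    /* nir_ige_imm()    */
         if (has_buddy)                                 /* nir_bcsel()      */
            lanes[lane] = prev[lane] + prev[lane - i];  /* shuffle_up + add */
      }
   }
}

An exclusive scan is the same loop followed by one more shuffle-up by 1,
with lane 0 filled in with the operation's identity.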

Part-of: <https://gitlab.freedesktop.org/mesa/mesa/-/merge_requests/26264>

---

 src/nouveau/compiler/meson.build                 |   1 +
 src/nouveau/compiler/nak_nir.c                   |   1 +
 src/nouveau/compiler/nak_nir_lower_scan_reduce.c | 197 +++++++++++++++++++++++
 src/nouveau/compiler/nak_private.h               |   1 +
 4 files changed, 200 insertions(+)

diff --git a/src/nouveau/compiler/meson.build b/src/nouveau/compiler/meson.build
index 2a77a6e921c..a6c95c5be6d 100644
--- a/src/nouveau/compiler/meson.build
+++ b/src/nouveau/compiler/meson.build
@@ -22,6 +22,7 @@ libnak_c_files = files(
   'nak.h',
   'nak_nir.c',
   'nak_nir_add_barriers.c',
+  'nak_nir_lower_scan_reduce.c',
   'nak_nir_lower_tex.c',
   'nak_nir_lower_vtg_io.c',
   'nak_nir_lower_gs_intrinsics.c',
diff --git a/src/nouveau/compiler/nak_nir.c b/src/nouveau/compiler/nak_nir.c
index a6639eabd7b..ef748adfdb1 100644
--- a/src/nouveau/compiler/nak_nir.c
+++ b/src/nouveau/compiler/nak_nir.c
@@ -297,6 +297,7 @@ nak_preprocess_nir(nir_shader *nir, const struct nak_compiler *nak)
       .lower_inverse_ballot = true,
    };
    OPT(nir, nir_lower_subgroups, &subgroups_options);
+   OPT(nir, nak_nir_lower_scan_reduce);
 }
 
 static uint16_t
diff --git a/src/nouveau/compiler/nak_nir_lower_scan_reduce.c b/src/nouveau/compiler/nak_nir_lower_scan_reduce.c
new file mode 100644
index 00000000000..2d5923190d6
--- /dev/null
+++ b/src/nouveau/compiler/nak_nir_lower_scan_reduce.c
@@ -0,0 +1,197 @@
+/*
+ * Copyright © 2023 Collabora, Ltd.
+ * SPDX-License-Identifier: MIT
+ */
+
+#include "nak_private.h"
+#include "nir_builder.h"
+
+static nir_def *
+build_identity(nir_builder *b, nir_op op)
+{
+   nir_const_value ident_const = nir_alu_binop_identity(op, 32);
+   return nir_build_imm(b, 1, 32, &ident_const);
+}
+
+/* Implementation of scan/reduce that assumes a full subgroup */
+static nir_def *
+build_scan_full(nir_builder *b, nir_intrinsic_op op, nir_op red_op,
+                nir_def *data, unsigned cluster_size)
+{
+   switch (op) {
+   case nir_intrinsic_exclusive_scan:
+   case nir_intrinsic_inclusive_scan: {
+      for (unsigned i = 1; i < cluster_size; i *= 2) {
+         nir_def *idx = nir_load_subgroup_invocation(b);
+         nir_def *has_buddy = nir_ige_imm(b, idx, i);
+
+         nir_def *buddy_data = nir_shuffle_up(b, data, nir_imm_int(b, i));
+         nir_def *accum = nir_build_alu2(b, red_op, data, buddy_data);
+         data = nir_bcsel(b, has_buddy, accum, data);
+      }
+
+      if (op == nir_intrinsic_exclusive_scan) {
+         /* For exclusive scans, we need to shift one more time and fill in the
+          * bottom channel with identity.
+          */
+         assert(cluster_size == 32);
+         nir_def *idx = nir_load_subgroup_invocation(b);
+         nir_def *has_buddy = nir_ige_imm(b, idx, 1);
+
+         nir_def *buddy_data = nir_shuffle_up(b, data, nir_imm_int(b, 1));
+         data = nir_bcsel(b, has_buddy, buddy_data, build_identity(b, red_op));
+      }
+
+      return data;
+   }
+
+   case nir_intrinsic_reduce: {
+      for (unsigned i = 1; i < cluster_size; i *= 2) {
+         nir_def *buddy_data = nir_shuffle_xor(b, data, nir_imm_int(b, i));
+         data = nir_build_alu2(b, red_op, data, buddy_data);
+      }
+      return data;
+   }
+
+   default:
+      unreachable("Unsupported scan/reduce op");
+   }
+}
+
+/* Fully generic implementation of scan/reduce that takes a mask */
+static nir_def *
+build_scan_reduce(nir_builder *b, nir_intrinsic_op op, nir_op red_op,
+                  nir_def *data, nir_def *mask, unsigned max_mask_bits)
+{
+   nir_def *lt_mask = nir_load_subgroup_lt_mask(b, 1, 32);
+
+   /* Mask of all channels whose values we need to accumulate.  Our own value
+    * is already in accum, if inclusive, thanks to the initialization above.
+    * We only need to consider lower indexed invocations.
+    */
+   nir_def *remaining = nir_iand(b, mask, lt_mask);
+
+   for (unsigned i = 1; i < max_mask_bits; i *= 2) {
+      /* At each step, our buddy channel is the first channel we have yet to
+       * take into account in the accumulator.
+       */
+      nir_def *has_buddy = nir_ine_imm(b, remaining, 0);
+      nir_def *buddy = nir_ufind_msb(b, remaining);
+
+      /* Accumulate with our buddy channel, if any */
+      nir_def *buddy_data = nir_shuffle(b, data, buddy);
+      nir_def *accum = nir_build_alu2(b, red_op, data, buddy_data);
+      data = nir_bcsel(b, has_buddy, accum, data);
+
+      /* We just took into account everything in our buddy's accumulator from
+       * the previous step.  The only things remaining are whatever channels
+       * were remaining for our buddy.
+       */
+      nir_def *buddy_remaining = nir_shuffle(b, remaining, buddy);
+      remaining = nir_bcsel(b, has_buddy, buddy_remaining, nir_imm_int(b, 0));
+   }
+
+   switch (op) {
+   case nir_intrinsic_exclusive_scan: {
+      /* For exclusive scans, we need to shift one more time and fill in the
+       * bottom channel with identity.
+       *
+       * Some of this will get CSE'd with the first step but that's okay. The
+       * code is cleaner this way.
+       */
+      nir_def *lower = nir_iand(b, mask, lt_mask);
+      nir_def *has_buddy = nir_ine_imm(b, lower, 0);
+      nir_def *buddy = nir_ufind_msb(b, lower);
+
+      nir_def *buddy_data = nir_shuffle(b, data, buddy);
+      return nir_bcsel(b, has_buddy, buddy_data, build_identity(b, red_op));
+   }
+
+   case nir_intrinsic_inclusive_scan:
+      return data;
+
+   case nir_intrinsic_reduce: {
+      /* For reductions, we need to take the top value of the scan */
+      nir_def *idx = nir_ufind_msb(b, mask);
+      return nir_shuffle(b, data, idx);
+   }
+
+   default:
+      unreachable("Unsupported scan/reduce op");
+   }
+}
+
+static bool
+nak_nir_lower_scan_reduce_intrin(nir_builder *b,
+                                 nir_intrinsic_instr *intrin,
+                                 UNUSED void *_data)
+{
+   switch (intrin->intrinsic) {
+   case nir_intrinsic_exclusive_scan:
+   case nir_intrinsic_inclusive_scan:
+   case nir_intrinsic_reduce:
+      break;
+   default:
+      return false;
+   }
+
+   const nir_op red_op = nir_intrinsic_reduction_op(intrin);
+
+   /* Grab the cluster size, defaulting to 32 */
+   unsigned cluster_size = 32;
+   if (nir_intrinsic_has_cluster_size(intrin)) {
+      cluster_size = nir_intrinsic_cluster_size(intrin);
+      if (cluster_size == 0 || cluster_size > 32)
+         cluster_size = 32;
+   }
+
+   b->cursor = nir_before_instr(&intrin->instr);
+
+   nir_def *data;
+   if (cluster_size == 1) {
+      /* Simple case where we're not actually doing any reducing at all. */
+      assert(intrin->intrinsic == nir_intrinsic_reduce);
+      data = intrin->src[0].ssa;
+   } else {
+      /* First, we need a mask of all invocations to be included in the
+       * reduction or scan.  For trivial cluster sizes, that's just the mask
+       * of enabled channels.
+       */
+      nir_def *mask = nir_ballot(b, 1, 32, nir_imm_true(b));
+      if (cluster_size < 32) {
+         nir_def *idx = nir_load_subgroup_invocation(b);
+         nir_def *cluster = nir_iand_imm(b, idx, ~(uint64_t)(cluster_size - 1));
+
+         nir_def *cluster_mask = nir_imm_int(b, BITFIELD_MASK(cluster_size));
+         cluster_mask = nir_ishl(b, cluster_mask, cluster);
+
+         mask = nir_iand(b, mask, cluster_mask);
+      }
+
+      nir_def *full, *partial;
+      nir_push_if(b, nir_ieq_imm(b, mask, -1));
+      {
+         full = build_scan_full(b, intrin->intrinsic, red_op,
+                                intrin->src[0].ssa, cluster_size);
+      }
+      nir_push_else(b, NULL);
+      {
+         partial = build_scan_reduce(b, intrin->intrinsic, red_op,
+                                     intrin->src[0].ssa, mask, cluster_size);
+      }
+      nir_pop_if(b, NULL);
+      data = nir_if_phi(b, full, partial);
+   }
+
+   nir_def_rewrite_uses(&intrin->def, data);
+   nir_instr_remove(&intrin->instr);
+
+   return true;
+}
+
+bool
+nak_nir_lower_scan_reduce(nir_shader *nir)
+{
+   return nir_shader_intrinsics_pass(nir, nak_nir_lower_scan_reduce_intrin,
+                                     nir_metadata_none, NULL);
+}
diff --git a/src/nouveau/compiler/nak_private.h b/src/nouveau/compiler/nak_private.h
index b81b2eaef81..cd07cc42559 100644
--- a/src/nouveau/compiler/nak_private.h
+++ b/src/nouveau/compiler/nak_private.h
@@ -144,6 +144,7 @@ struct nak_nir_tex_flags {
    uint32_t pad:26;
 };
 
+bool nak_nir_lower_scan_reduce(nir_shader *shader);
 bool nak_nir_lower_tex(nir_shader *nir, const struct nak_compiler *nak);
 bool nak_nir_lower_gs_intrinsics(nir_shader *shader);
 
