Module: Mesa
Branch: main
Commit: 700a2dc648a5e73c20e456b5144418a8b405f985
URL:    http://cgit.freedesktop.org/mesa/mesa/commit/?id=700a2dc648a5e73c20e456b5144418a8b405f985

Author: Karol Herbst <[email protected]>
Date:   Fri Oct 20 19:47:13 2023 +0200

zink: alias nir scratch memory by lowering to common bit_size

This lowers all scratch accesses to a common bit size so that they alias each
other, as required by OpenCL. It's up to the Vulkan driver to vectorize them
into wider loads/stores if possible.

Signed-off-by: Karol Herbst <[email protected]>
Acked-by: Mike Blumenkrantz <[email protected]>
Part-of: <https://gitlab.freedesktop.org/mesa/mesa/-/merge_requests/25837>

---

 src/gallium/drivers/zink/zink_compiler.c | 52 +++++++++++++++++++++++++++++++-
 1 file changed, 51 insertions(+), 1 deletion(-)

diff --git a/src/gallium/drivers/zink/zink_compiler.c b/src/gallium/drivers/zink/zink_compiler.c
index 95fe41ff4f9..646880b7d08 100644
--- a/src/gallium/drivers/zink/zink_compiler.c
+++ b/src/gallium/drivers/zink/zink_compiler.c
@@ -5311,6 +5311,55 @@ mem_access_size_align_cb(nir_intrinsic_op intrin, uint8_t bytes,
    };
 }
 
+/* nir_lower_mem_access_bit_sizes() callback for scratch memory.
+ *
+ * Every scratch access is split into chunks of the single bit size passed in
+ * through @cb_data (the smallest scratch access size found in the shader),
+ * so that all scratch accesses alias the same elements. The per-access
+ * @bit_size argument is ignored. At most 4 components are emitted per chunk,
+ * and the reported alignment is the chunk size in bytes.
+ */
+static nir_mem_access_size_align
+mem_access_scratch_size_align_cb(nir_intrinsic_op intrin, uint8_t bytes,
+                                 uint8_t bit_size, uint32_t align,
+                                 uint32_t align_offset, bool offset_is_const,
+                                 const void *cb_data)
+{
+   /* Use a distinct local instead of overwriting the bit_size parameter. */
+   const uint8_t alias_bit_size = *(const uint8_t *)cb_data;
+   const unsigned alias_bytes = alias_bit_size / 8;
+
+   /* Only consumed by assert(); ASSERTED keeps NDEBUG builds free of
+    * set-but-unused warnings.
+    */
+   ASSERTED const uint32_t combined_align = nir_combined_align(align, align_offset);
+   assert(util_is_power_of_two_nonzero(combined_align));
+   /* A sub-byte common size would make alias_bytes zero and the division
+    * below undefined.
+    */
+   assert(alias_bytes > 0);
+
+   return (nir_mem_access_size_align){
+      .num_components = MIN2(bytes / alias_bytes, 4),
+      .bit_size = alias_bit_size,
+      .align = alias_bytes,
+   };
+}
+
+/* Read-only intrinsics-pass callback: narrows the uint8_t pointed to by @data
+ * to the smallest bit size seen on any load_scratch result or store_scratch
+ * value. It never modifies the shader, so it always returns false.
+ */
+static bool
+alias_scratch_memory_scan_bit_size(struct nir_builder *b, nir_intrinsic_instr *instr, void *data)
+{
+   uint8_t *min_bit_size = data;
+   unsigned access_bit_size;
+
+   if (instr->intrinsic == nir_intrinsic_load_scratch)
+      access_bit_size = instr->def.bit_size;
+   else if (instr->intrinsic == nir_intrinsic_store_scratch)
+      access_bit_size = instr->src[0].ssa->bit_size;
+   else
+      return false;
+
+   if (access_bit_size < *min_bit_size)
+      *min_bit_size = access_bit_size;
+
+   return false;
+}
+
+/* Lower every nir_var_function_temp (scratch) access to the narrowest bit size
+ * used by any load_scratch/store_scratch in the shader, so that all scratch
+ * accesses alias each other. Returns the result of
+ * nir_lower_mem_access_bit_sizes().
+ */
+static bool
+alias_scratch_memory(nir_shader *nir)
+{
+   /* Start at the widest size; the scan below only ever narrows it. */
+   uint8_t common_bit_size = 64;
+   nir_shader_intrinsics_pass(nir, alias_scratch_memory_scan_bit_size,
+                              nir_metadata_all, &common_bit_size);
+
+   /* cb_data points at the local above; it is only read during the call. */
+   nir_lower_mem_access_bit_sizes_options opts = {
+      .modes = nir_var_function_temp,
+      .may_lower_unaligned_stores_to_atomics = true,
+      .callback = mem_access_scratch_size_align_cb,
+      .cb_data = &common_bit_size,
+   };
+   return nir_lower_mem_access_bit_sizes(nir, &opts);
+}
+
 static uint8_t
 lower_vec816_alu(const nir_instr *instr, const void *cb_data)
 {
@@ -5358,12 +5407,13 @@ zink_shader_create(struct zink_screen *screen, struct nir_shader *nir)
 
    if (nir->info.stage == MESA_SHADER_KERNEL) {
       nir_lower_mem_access_bit_sizes_options lower_mem_access_options = {
-         .modes = nir_var_all,
+         .modes = nir_var_all ^ nir_var_function_temp,
          .may_lower_unaligned_stores_to_atomics = true,
          .callback = mem_access_size_align_cb,
          .cb_data = screen,
       };
       NIR_PASS_V(nir, nir_lower_mem_access_bit_sizes, 
&lower_mem_access_options);
+      NIR_PASS_V(nir, alias_scratch_memory);
       NIR_PASS_V(nir, nir_lower_alu_width, lower_vec816_alu, NULL);
       NIR_PASS_V(nir, nir_lower_alu_vec8_16_srcs);
    }

Reply via email to