Module: Mesa
Branch: main
Commit: 69d123d88eeca4b4ad52c78a3eb566f163721273
URL:    http://cgit.freedesktop.org/mesa/mesa/commit/?id=69d123d88eeca4b4ad52c78a3eb566f163721273

Author: Mike Blumenkrantz <[email protected]>
Date:   Tue Sep  6 14:25:38 2022 -0400

zink: fix sharedmem ops with bit_size!=32

* the rewrite_bo_access compiler pass already handles 64-bit rewrites as needed
* shared memory access is not required to be 32-bit

Thus, this can use an approach similar to the one used for SSBO/UBO variables:
index based on bit size and handle operations through sized variables.

Acked-by: Erik Faye-Lund <[email protected]>
Part-of: <https://gitlab.freedesktop.org/mesa/mesa/-/merge_requests/18449>

---

 .../drivers/zink/nir_to_spirv/nir_to_spirv.c       | 74 ++++++++++++----------
 1 file changed, 40 insertions(+), 34 deletions(-)

diff --git a/src/gallium/drivers/zink/nir_to_spirv/nir_to_spirv.c 
b/src/gallium/drivers/zink/nir_to_spirv/nir_to_spirv.c
index 555147a096e..ab6be0336e3 100644
--- a/src/gallium/drivers/zink/nir_to_spirv/nir_to_spirv.c
+++ b/src/gallium/drivers/zink/nir_to_spirv/nir_to_spirv.c
@@ -87,6 +87,8 @@ struct ntv_context {
    bool block_started;
    SpvId loop_break, loop_cont;
 
+   SpvId shared_block_var[5]; //8, 16, 32, unused, 64
+
    SpvId front_face_var, instance_id_var, vertex_id_var,
          primitive_id_var, invocation_id_var, // geometry
          sample_mask_type, sample_id_var, sample_pos_var, sample_mask_in_var,
@@ -96,7 +98,6 @@ struct ntv_context {
          local_invocation_id_var, global_invocation_id_var,
          local_invocation_index_var, helper_invocation_var,
          local_group_size_var,
-         shared_block_var,
          base_vertex_var, base_instance_var, draw_id_var;
 
    SpvId subgroup_eq_mask_var,
@@ -455,21 +456,33 @@ get_glsl_type(struct ntv_context *ctx, const struct 
glsl_type *type)
 }
 
 static void
-create_shared_block(struct ntv_context *ctx, unsigned shared_size)
+create_shared_block(struct ntv_context *ctx, unsigned shared_size, unsigned 
bit_size)
 {
-   SpvId type = spirv_builder_type_uint(&ctx->builder, 32);
-   SpvId array = spirv_builder_type_array(&ctx->builder, type, 
emit_uint_const(ctx, 32, shared_size / 4));
-   spirv_builder_emit_array_stride(&ctx->builder, array, 4);
+   unsigned idx = bit_size >> 4;
+   SpvId type = spirv_builder_type_uint(&ctx->builder, bit_size);
+   unsigned block_size = shared_size / (bit_size / 8);
+   assert(block_size);
+   SpvId array = spirv_builder_type_array(&ctx->builder, type, 
emit_uint_const(ctx, 32, block_size));
+   spirv_builder_emit_array_stride(&ctx->builder, array, bit_size / 8);
    SpvId ptr_type = spirv_builder_type_pointer(&ctx->builder,
                                                SpvStorageClassWorkgroup,
                                                array);
-   ctx->shared_block_var = spirv_builder_emit_var(&ctx->builder, ptr_type, 
SpvStorageClassWorkgroup);
+   ctx->shared_block_var[idx] = spirv_builder_emit_var(&ctx->builder, 
ptr_type, SpvStorageClassWorkgroup);
    if (ctx->spirv_1_4_interfaces) {
       assert(ctx->num_entry_ifaces < ARRAY_SIZE(ctx->entry_ifaces));
-      ctx->entry_ifaces[ctx->num_entry_ifaces++] = ctx->shared_block_var;
+      ctx->entry_ifaces[ctx->num_entry_ifaces++] = ctx->shared_block_var[idx];
    }
 }
 
+static SpvId
+get_shared_block(struct ntv_context *ctx, unsigned bit_size)
+{
+   unsigned idx = bit_size >> 4;
+   if (!ctx->shared_block_var[idx])
+      create_shared_block(ctx, ctx->nir->info.shared_size, bit_size);
+   return ctx->shared_block_var[idx];
+}
+
 #define HANDLE_EMIT_BUILTIN(SLOT, BUILTIN) \
       case VARYING_SLOT_##SLOT: \
          spirv_builder_emit_builtin(&ctx->builder, var_id, 
SpvBuiltIn##BUILTIN); \
@@ -2433,18 +2446,19 @@ emit_load_shared(struct ntv_context *ctx, 
nir_intrinsic_instr *intr)
    SpvId dest_type = get_dest_type(ctx, &intr->dest, nir_type_uint);
    unsigned num_components = nir_dest_num_components(intr->dest);
    unsigned bit_size = nir_dest_bit_size(intr->dest);
-   SpvId uint_type = get_uvec_type(ctx, 32, 1);
+   SpvId uint_type = get_uvec_type(ctx, bit_size, 1);
    SpvId ptr_type = spirv_builder_type_pointer(&ctx->builder,
                                                SpvStorageClassWorkgroup,
                                                uint_type);
    SpvId offset = get_src(ctx, &intr->src[0]);
    SpvId constituents[NIR_MAX_VEC_COMPONENTS];
+   SpvId shared_block = get_shared_block(ctx, bit_size);
    /* need to convert array -> vec */
    for (unsigned i = 0; i < num_components; i++) {
       SpvId member = spirv_builder_emit_access_chain(&ctx->builder, ptr_type,
-                                                     ctx->shared_block_var, 
&offset, 1);
+                                                     shared_block, &offset, 1);
       constituents[i] = spirv_builder_emit_load(&ctx->builder, uint_type, 
member);
-      offset = emit_binop(ctx, SpvOpIAdd, uint_type, offset, 
emit_uint_const(ctx, 32, 1));
+      offset = emit_binop(ctx, SpvOpIAdd, 
spirv_builder_type_uint(&ctx->builder, 32), offset, emit_uint_const(ctx, 32, 
1));
    }
    SpvId result;
    if (num_components > 1)
@@ -2458,31 +2472,24 @@ static void
 emit_store_shared(struct ntv_context *ctx, nir_intrinsic_instr *intr)
 {
    SpvId src = get_src(ctx, &intr->src[0]);
-   bool qword = nir_src_bit_size(intr->src[0]) == 64;
 
-   unsigned num_writes = util_bitcount(nir_intrinsic_write_mask(intr));
    unsigned wrmask = nir_intrinsic_write_mask(intr);
-   /* this is a partial write, so we have to loop and do a per-component write 
*/
-   SpvId uint_type = get_uvec_type(ctx, 32, 1);
+   unsigned bit_size = nir_src_bit_size(intr->src[0]);
+   SpvId uint_type = get_uvec_type(ctx, bit_size, 1);
    SpvId ptr_type = spirv_builder_type_pointer(&ctx->builder,
                                                SpvStorageClassWorkgroup,
                                                uint_type);
    SpvId offset = get_src(ctx, &intr->src[1]);
-
-   for (unsigned i = 0; num_writes; i++) {
-      if ((wrmask >> i) & 1) {
-         for (unsigned j = 0; j < 1 + !!qword; j++) {
-            unsigned comp = ((1 + !!qword) * i) + j;
-            SpvId shared_offset = emit_binop(ctx, SpvOpIAdd, uint_type, 
offset, emit_uint_const(ctx, 32, comp));
-            SpvId val = src;
-            if (nir_src_num_components(intr->src[0]) != 1 || qword)
-               val = spirv_builder_emit_composite_extract(&ctx->builder, 
uint_type, src, &comp, 1);
-            SpvId member = spirv_builder_emit_access_chain(&ctx->builder, 
ptr_type,
-                                                           
ctx->shared_block_var, &shared_offset, 1);
-            spirv_builder_emit_store(&ctx->builder, member, val);
-         }
-         num_writes--;
-      }
+   SpvId shared_block = get_shared_block(ctx, bit_size);
+   /* this is a partial write, so we have to loop and do a per-component write 
*/
+   u_foreach_bit(i, wrmask) {
+      SpvId shared_offset = emit_binop(ctx, SpvOpIAdd, 
spirv_builder_type_uint(&ctx->builder, 32), offset, emit_uint_const(ctx, 32, 
i));
+      SpvId val = src;
+      if (nir_src_num_components(intr->src[0]) != 1)
+         val = spirv_builder_emit_composite_extract(&ctx->builder, uint_type, 
src, &i, 1);
+      SpvId member = spirv_builder_emit_access_chain(&ctx->builder, ptr_type,
+                                                     shared_block, 
&shared_offset, 1);
+      spirv_builder_emit_store(&ctx->builder, member, val);
    }
 }
 
@@ -2698,15 +2705,17 @@ emit_deref_atomic_intrinsic(struct ntv_context *ctx, 
nir_intrinsic_instr *intr)
 static void
 emit_shared_atomic_intrinsic(struct ntv_context *ctx, nir_intrinsic_instr 
*intr)
 {
-   SpvId dest_type = get_dest_type(ctx, &intr->dest, nir_type_uint32);
+   unsigned bit_size = nir_src_bit_size(intr->src[1]);
+   SpvId dest_type = get_dest_type(ctx, &intr->dest, nir_type_uint);
    SpvId param = get_src(ctx, &intr->src[1]);
 
    SpvId pointer_type = spirv_builder_type_pointer(&ctx->builder,
                                                    SpvStorageClassWorkgroup,
                                                    dest_type);
    SpvId offset = emit_binop(ctx, SpvOpUDiv, get_uvec_type(ctx, 32, 1), 
get_src(ctx, &intr->src[0]), emit_uint_const(ctx, 32, 4));
+   SpvId shared_block = get_shared_block(ctx, bit_size);
    SpvId ptr = spirv_builder_emit_access_chain(&ctx->builder, pointer_type,
-                                               ctx->shared_block_var, &offset, 
1);
+                                               shared_block, &offset, 1);
 
    SpvId param2 = 0;
 
@@ -4397,9 +4406,6 @@ nir_to_spirv(struct nir_shader *s, const struct 
zink_shader_info *sinfo, uint32_
                                            MAX2(s->info.gs.vertices_out, 1));
       break;
    case MESA_SHADER_COMPUTE:
-      if (s->info.shared_size)
-         create_shared_block(&ctx, s->info.shared_size);
-
       if (s->info.workgroup_size[0] || s->info.workgroup_size[1] || 
s->info.workgroup_size[2])
          spirv_builder_emit_exec_mode_literal3(&ctx.builder, entry_point, 
SpvExecutionModeLocalSize,
                                                
(uint32_t[3]){(uint32_t)s->info.workgroup_size[0], 
(uint32_t)s->info.workgroup_size[1],

Reply via email to