Module: Mesa
Branch: main
Commit: c212a00ebfb3901df95beabd2bcce76715581060
URL:    http://cgit.freedesktop.org/mesa/mesa/commit/?id=c212a00ebfb3901df95beabd2bcce76715581060

Author: Mike Blumenkrantz <[email protected]>
Date:   Tue Sep  6 15:50:23 2022 -0400

zink: handle 64bit float atomics

Part-of: <https://gitlab.freedesktop.org/mesa/mesa/-/merge_requests/18144>

---

 .../drivers/zink/nir_to_spirv/nir_to_spirv.c       | 64 ++++++++++++++++++++--
 src/gallium/drivers/zink/zink_compiler.c           | 16 ++++++
 2 files changed, 74 insertions(+), 6 deletions(-)

diff --git a/src/gallium/drivers/zink/nir_to_spirv/nir_to_spirv.c 
b/src/gallium/drivers/zink/nir_to_spirv/nir_to_spirv.c
index 0ed25f39f4f..f7ceff7d660 100644
--- a/src/gallium/drivers/zink/nir_to_spirv/nir_to_spirv.c
+++ b/src/gallium/drivers/zink/nir_to_spirv/nir_to_spirv.c
@@ -214,7 +214,7 @@ emit_access_decorations(struct ntv_context *ctx, 
nir_variable *var, SpvId var_id
 }
 
 static SpvOp
-get_atomic_op(nir_intrinsic_op op)
+get_atomic_op(struct ntv_context *ctx, unsigned bit_size, nir_intrinsic_op op)
 {
    switch (op) {
 #define CASE_ATOMIC_OP(type) \
@@ -222,6 +222,32 @@ get_atomic_op(nir_intrinsic_op op)
    case nir_intrinsic_image_deref_atomic_##type: \
    case nir_intrinsic_shared_atomic_##type
 
+#define ATOMIC_FCAP(NAME) \
+   do {\
+      if (bit_size == 16) \
+         spirv_builder_emit_cap(&ctx->builder, 
SpvCapabilityAtomicFloat16##NAME##EXT); \
+      if (bit_size == 32) \
+         spirv_builder_emit_cap(&ctx->builder, 
SpvCapabilityAtomicFloat32##NAME##EXT); \
+      if (bit_size == 64) \
+         spirv_builder_emit_cap(&ctx->builder, 
SpvCapabilityAtomicFloat64##NAME##EXT); \
+   } while (0)
+
+   CASE_ATOMIC_OP(fadd):
+      ATOMIC_FCAP(Add);
+      if (bit_size == 16)
+         spirv_builder_emit_extension(&ctx->builder, 
"SPV_EXT_shader_atomic_float16_add");
+      else
+         spirv_builder_emit_extension(&ctx->builder, 
"SPV_EXT_shader_atomic_float_add");
+      return SpvOpAtomicFAddEXT;
+   CASE_ATOMIC_OP(fmax):
+      ATOMIC_FCAP(MinMax);
+      spirv_builder_emit_extension(&ctx->builder, 
"SPV_EXT_shader_atomic_float_min_max");
+      return SpvOpAtomicFMaxEXT;
+   CASE_ATOMIC_OP(fmin):
+      ATOMIC_FCAP(MinMax);
+      spirv_builder_emit_extension(&ctx->builder, 
"SPV_EXT_shader_atomic_float_min_max");
+      return SpvOpAtomicFMinEXT;
+
    CASE_ATOMIC_OP(add):
       return SpvOpAtomicIAdd;
    CASE_ATOMIC_OP(umin):
@@ -248,7 +274,22 @@ get_atomic_op(nir_intrinsic_op op)
    }
    return 0;
 }
+
+static bool
+atomic_op_is_float(nir_intrinsic_op op)
+{
+   switch (op) {
+   CASE_ATOMIC_OP(fadd):
+   CASE_ATOMIC_OP(fmax):
+   CASE_ATOMIC_OP(fmin):
+      return true;
+   default:
+      break;
+   }
+   return false;
+}
 #undef CASE_ATOMIC_OP
+
 static SpvId
 emit_float_const(struct ntv_context *ctx, int bit_size, double value)
 {
@@ -2672,7 +2713,7 @@ static void
 handle_atomic_op(struct ntv_context *ctx, nir_intrinsic_instr *intr, SpvId 
ptr, SpvId param, SpvId param2, nir_alu_type type)
 {
    SpvId dest_type = get_dest_type(ctx, &intr->dest, type);
-   SpvId result = emit_atomic(ctx, get_atomic_op(intr->intrinsic), dest_type, 
ptr, param, param2);
+   SpvId result = emit_atomic(ctx, get_atomic_op(ctx, 
nir_dest_bit_size(intr->dest), intr->intrinsic), dest_type, ptr, param, param2);
    assert(result);
    store_dest(ctx, &intr->dest, result, type);
 }
@@ -2685,10 +2726,13 @@ emit_deref_atomic_intrinsic(struct ntv_context *ctx, 
nir_intrinsic_instr *intr)
 
    SpvId param2 = 0;
 
+   if (nir_src_bit_size(intr->src[1]) == 64)
+      spirv_builder_emit_cap(&ctx->builder, SpvCapabilityInt64Atomics);
+
    if (intr->intrinsic == nir_intrinsic_deref_atomic_comp_swap)
       param2 = get_src(ctx, &intr->src[2]);
 
-   handle_atomic_op(ctx, intr, ptr, param, param2, nir_type_uint32);
+   handle_atomic_op(ctx, intr, ptr, param, param2, 
atomic_op_is_float(intr->intrinsic) ? nir_type_float : nir_type_uint32);
 }
 
 static void
@@ -2701,17 +2745,18 @@ emit_shared_atomic_intrinsic(struct ntv_context *ctx, 
nir_intrinsic_instr *intr)
    SpvId pointer_type = spirv_builder_type_pointer(&ctx->builder,
                                                    SpvStorageClassWorkgroup,
                                                    dest_type);
-   SpvId offset = emit_binop(ctx, SpvOpUDiv, get_uvec_type(ctx, 32, 1), 
get_src(ctx, &intr->src[0]), emit_uint_const(ctx, 32, 4));
+   SpvId offset = emit_binop(ctx, SpvOpUDiv, get_uvec_type(ctx, 32, 1), 
get_src(ctx, &intr->src[0]), emit_uint_const(ctx, 32, bit_size / 8));
    SpvId shared_block = get_shared_block(ctx, bit_size);
    SpvId ptr = spirv_builder_emit_access_chain(&ctx->builder, pointer_type,
                                                shared_block, &offset, 1);
-
+   if (nir_src_bit_size(intr->src[1]) == 64)
+      spirv_builder_emit_cap(&ctx->builder, SpvCapabilityInt64Atomics);
    SpvId param2 = 0;
 
    if (intr->intrinsic == nir_intrinsic_shared_atomic_comp_swap)
       param2 = get_src(ctx, &intr->src[2]);
 
-   handle_atomic_op(ctx, intr, ptr, param, param2, nir_type_uint32);
+   handle_atomic_op(ctx, intr, ptr, param, param2, 
atomic_op_is_float(intr->intrinsic) ? nir_type_float : nir_type_uint32);
 }
 
 static void
@@ -3183,6 +3228,10 @@ emit_intrinsic(struct ntv_context *ctx, 
nir_intrinsic_instr *intr)
                                         SpvMemorySemanticsAcquireReleaseMask);
       break;
 
+   case nir_intrinsic_deref_atomic_fadd:
+   case nir_intrinsic_deref_atomic_fmin:
+   case nir_intrinsic_deref_atomic_fmax:
+   case nir_intrinsic_deref_atomic_fcomp_swap:
    case nir_intrinsic_deref_atomic_add:
    case nir_intrinsic_deref_atomic_umin:
    case nir_intrinsic_deref_atomic_imin:
@@ -3196,6 +3245,9 @@ emit_intrinsic(struct ntv_context *ctx, 
nir_intrinsic_instr *intr)
       emit_deref_atomic_intrinsic(ctx, intr);
       break;
 
+   case nir_intrinsic_shared_atomic_fadd:
+   case nir_intrinsic_shared_atomic_fmin:
+   case nir_intrinsic_shared_atomic_fmax:
    case nir_intrinsic_shared_atomic_add:
    case nir_intrinsic_shared_atomic_umin:
    case nir_intrinsic_shared_atomic_imin:
diff --git a/src/gallium/drivers/zink/zink_compiler.c 
b/src/gallium/drivers/zink/zink_compiler.c
index dce7d53d339..3ac3b0bb269 100644
--- a/src/gallium/drivers/zink/zink_compiler.c
+++ b/src/gallium/drivers/zink/zink_compiler.c
@@ -1218,6 +1218,18 @@ rewrite_atomic_ssbo_instr(nir_builder *b, nir_instr 
*instr, struct bo_vars *bo)
    nir_intrinsic_op op;
    nir_intrinsic_instr *intr = nir_instr_as_intrinsic(instr);
    switch (intr->intrinsic) {
+   case nir_intrinsic_ssbo_atomic_fadd:
+      op = nir_intrinsic_deref_atomic_fadd;
+      break;
+   case nir_intrinsic_ssbo_atomic_fmin:
+      op = nir_intrinsic_deref_atomic_fmin;
+      break;
+   case nir_intrinsic_ssbo_atomic_fmax:
+      op = nir_intrinsic_deref_atomic_fmax;
+      break;
+   case nir_intrinsic_ssbo_atomic_fcomp_swap:
+      op = nir_intrinsic_deref_atomic_fcomp_swap;
+      break;
    case nir_intrinsic_ssbo_atomic_add:
       op = nir_intrinsic_deref_atomic_add;
       break;
@@ -1297,6 +1309,10 @@ remove_bo_access_instr(nir_builder *b, nir_instr *instr, 
void *data)
    nir_src *src;
    bool ssbo = true;
    switch (intr->intrinsic) {
+   case nir_intrinsic_ssbo_atomic_fadd:
+   case nir_intrinsic_ssbo_atomic_fmin:
+   case nir_intrinsic_ssbo_atomic_fmax:
+   case nir_intrinsic_ssbo_atomic_fcomp_swap:
    case nir_intrinsic_ssbo_atomic_add:
    case nir_intrinsic_ssbo_atomic_umin:
    case nir_intrinsic_ssbo_atomic_imin:

Reply via email to