Module: Mesa
Branch: main
Commit: 5e7c828c0e4c81a36ecf52766183f228ae0c042c
URL:    http://cgit.freedesktop.org/mesa/mesa/commit/?id=5e7c828c0e4c81a36ecf52766183f228ae0c042c

Author: Bas Nieuwenhuizen <[email protected]>
Date:   Sat Jul 15 19:49:49 2023 +0200

aco: Add WMMA instructions.

Part-of: <https://gitlab.freedesktop.org/mesa/mesa/-/merge_requests/24683>

---

 src/amd/compiler/aco_insert_waitcnt.cpp            |  3 +-
 src/amd/compiler/aco_instruction_selection.cpp     | 43 ++++++++++++++++++++++
 .../compiler/aco_instruction_selection_setup.cpp   |  3 +-
 src/amd/compiler/aco_ir.h                          |  1 +
 src/amd/compiler/aco_opcodes.py                    |  7 ++++
 src/amd/compiler/aco_optimizer.cpp                 | 16 +++++++-
 src/amd/compiler/aco_statistics.cpp                |  5 +++
 src/amd/compiler/aco_validate.cpp                  |  5 ++-
 8 files changed, 77 insertions(+), 6 deletions(-)

diff --git a/src/amd/compiler/aco_insert_waitcnt.cpp b/src/amd/compiler/aco_insert_waitcnt.cpp
index 6e292a9bb72..56b4613847a 100644
--- a/src/amd/compiler/aco_insert_waitcnt.cpp
+++ b/src/amd/compiler/aco_insert_waitcnt.cpp
@@ -875,7 +875,8 @@ gen_alu(Instruction* instr, wait_ctx& ctx)
       for (const Definition& def : instr->definitions)
          insert_wait_entry(ctx, def, event, 0, cycle_info.latency);
    }
-   update_alu(ctx, is_valu, is_trans, clear, cycle_info.issue_cycles);
+   update_alu(ctx, is_valu && instr_info.classes[(int)instr->opcode] != 
instr_class::wmma, is_trans,
+              clear, cycle_info.issue_cycles);
 }
 
 void
diff --git a/src/amd/compiler/aco_instruction_selection.cpp b/src/amd/compiler/aco_instruction_selection.cpp
index 357d566e07e..8dd18f72043 100644
--- a/src/amd/compiler/aco_instruction_selection.cpp
+++ b/src/amd/compiler/aco_instruction_selection.cpp
@@ -8087,6 +8087,48 @@ create_fs_dual_src_export_gfx11(isel_context* ctx, const struct aco_export_mrt*
    ctx->program->has_color_exports = true;
 }
 
+static void
+visit_cmat_muladd(isel_context* ctx, nir_intrinsic_instr* instr)
+{
+   aco_opcode opcode = aco_opcode::num_opcodes;
+   unsigned signed_mask = 0;
+   bool clamp = false;
+
+   switch (instr->src[0].ssa->bit_size) {
+   case 16:
+      switch (instr->def.bit_size) {
+      case 32: opcode = aco_opcode::v_wmma_f32_16x16x16_f16; break;
+      case 16: opcode = aco_opcode::v_wmma_f16_16x16x16_f16; break;
+      }
+      break;
+   case 8:
+      opcode = aco_opcode::v_wmma_i32_16x16x16_iu8;
+      signed_mask = nir_intrinsic_cmat_signed_mask(instr);
+      clamp = nir_intrinsic_saturate(instr);
+      break;
+   }
+
+   if (opcode == aco_opcode::num_opcodes)
+      unreachable("visit_cmat_muladd: invalid bit size combination");
+
+   Builder bld(ctx->program, ctx->block);
+
+   Temp dst = get_ssa_temp(ctx, &instr->def);
+   Operand A(as_vgpr(ctx, get_ssa_temp(ctx, instr->src[0].ssa)));
+   Operand B(as_vgpr(ctx, get_ssa_temp(ctx, instr->src[1].ssa)));
+   Operand C(as_vgpr(ctx, get_ssa_temp(ctx, instr->src[2].ssa)));
+
+   A.setLateKill(true);
+   B.setLateKill(true);
+
+   VALU_instruction& vop3p = bld.vop3p(opcode, Definition(dst), A, B, C, 0, 
0)->valu();
+   vop3p.neg_lo[0] = (signed_mask & 0x1) != 0;
+   vop3p.neg_lo[1] = (signed_mask & 0x2) != 0;
+   vop3p.clamp = clamp;
+
+   emit_split_vector(ctx, dst, instr->def.num_components);
+}
+
 void
 visit_intrinsic(isel_context* ctx, nir_intrinsic_instr* instr)
 {
@@ -9174,6 +9216,7 @@ visit_intrinsic(isel_context* ctx, nir_intrinsic_instr* instr)
          bld.pseudo(aco_opcode::p_pops_gfx9_ordered_section_done);
       break;
    }
+   case nir_intrinsic_cmat_muladd_amd: visit_cmat_muladd(ctx, instr); break;
    default:
       isel_err(&instr->instr, "Unimplemented intrinsic instr");
       abort();
diff --git a/src/amd/compiler/aco_instruction_selection_setup.cpp b/src/amd/compiler/aco_instruction_selection_setup.cpp
index e7e56549205..4e27cdc8bcf 100644
--- a/src/amd/compiler/aco_instruction_selection_setup.cpp
+++ b/src/amd/compiler/aco_instruction_selection_setup.cpp
@@ -535,7 +535,8 @@ init_context(isel_context* ctx, nir_shader* shader)
                case nir_intrinsic_bvh64_intersect_ray_amd:
                case nir_intrinsic_load_vector_arg_amd:
                case nir_intrinsic_load_rt_dynamic_callable_stack_base_amd:
-               case nir_intrinsic_ordered_xfb_counter_add_amd: type = 
RegType::vgpr; break;
+               case nir_intrinsic_ordered_xfb_counter_add_amd:
+               case nir_intrinsic_cmat_muladd_amd: type = RegType::vgpr; break;
                case nir_intrinsic_load_shared:
                case nir_intrinsic_load_shared2_amd:
                   /* When the result of these loads is only used by cross-lane 
instructions,
diff --git a/src/amd/compiler/aco_ir.h b/src/amd/compiler/aco_ir.h
index d521c4d929a..ca9c2efac5b 100644
--- a/src/amd/compiler/aco_ir.h
+++ b/src/amd/compiler/aco_ir.h
@@ -133,6 +133,7 @@ enum class instr_class : uint8_t {
    vmem = 17,
    waitcnt = 18,
    other = 19,
+   wmma = 20,
    count,
 };
 
diff --git a/src/amd/compiler/aco_opcodes.py b/src/amd/compiler/aco_opcodes.py
index b5b22926af9..71bc53a06c9 100644
--- a/src/amd/compiler/aco_opcodes.py
+++ b/src/amd/compiler/aco_opcodes.py
@@ -48,6 +48,7 @@ class InstrClass(Enum):
    VMem = 17
    Waitcnt = 18
    Other = 19
+   WMMA = 20
 
 class Format(Enum):
    PSEUDO = 0
@@ -1051,6 +1052,12 @@ opcode("v_dot8_i32_iu4", -1, -1, -1, 0x18, Format.VOP3P, InstrClass.Valu32)
 opcode("v_dot8_u32_u4", -1, 0x2b, 0x19, 0x19, Format.VOP3P, InstrClass.Valu32)
 opcode("v_dot2_f32_f16", -1, 0x23, 0x13, 0x13, Format.VOP3P, InstrClass.Valu32)
 opcode("v_dot2_f32_bf16", -1, -1, -1, 0x1a, Format.VOP3P, InstrClass.Valu32)
+opcode("v_wmma_f32_16x16x16_f16", -1, -1, -1, 0x40, Format.VOP3P, 
InstrClass.WMMA, False, False)
+opcode("v_wmma_f32_16x16x16_bf16", -1, -1, -1, 0x41, Format.VOP3P, 
InstrClass.WMMA, False, False)
+opcode("v_wmma_f16_16x16x16_f16", -1, -1, -1, 0x42, Format.VOP3P, 
InstrClass.WMMA, False, False)
+opcode("v_wmma_bf16_16x16x16_bf16", -1, -1, -1, 0x43, Format.VOP3P, 
InstrClass.WMMA, False, False)
+opcode("v_wmma_i32_16x16x16_iu8", -1, -1, -1, 0x44, Format.VOP3P, 
InstrClass.WMMA, False, False)
+opcode("v_wmma_i32_16x16x16_iu4", -1, -1, -1, 0x45, Format.VOP3P, 
InstrClass.WMMA, False, False)
 
 
 # VINTRP (GFX6 - GFX10.3) instructions:
diff --git a/src/amd/compiler/aco_optimizer.cpp b/src/amd/compiler/aco_optimizer.cpp
index 311b8ad4d04..9fdbffc7994 100644
--- a/src/amd/compiler/aco_optimizer.cpp
+++ b/src/amd/compiler/aco_optimizer.cpp
@@ -643,7 +643,13 @@ can_apply_sgprs(opt_ctx& ctx, aco_ptr<Instruction>& instr)
           instr->opcode != aco_opcode::v_interp_p10_f16_f32_inreg &&
           instr->opcode != aco_opcode::v_interp_p2_f16_f32_inreg &&
           instr->opcode != aco_opcode::v_interp_p10_rtz_f16_f32_inreg &&
-          instr->opcode != aco_opcode::v_interp_p2_rtz_f16_f32_inreg;
+          instr->opcode != aco_opcode::v_interp_p2_rtz_f16_f32_inreg &&
+          instr->opcode != aco_opcode::v_wmma_f32_16x16x16_f16 &&
+          instr->opcode != aco_opcode::v_wmma_f32_16x16x16_bf16 &&
+          instr->opcode != aco_opcode::v_wmma_f16_16x16x16_f16 &&
+          instr->opcode != aco_opcode::v_wmma_bf16_16x16x16_bf16 &&
+          instr->opcode != aco_opcode::v_wmma_i32_16x16x16_iu8 &&
+          instr->opcode != aco_opcode::v_wmma_i32_16x16x16_iu4;
 }
 
 bool
@@ -697,7 +703,13 @@ alu_can_accept_constant(const aco_ptr<Instruction>& instr, unsigned operand)
    case aco_opcode::v_interp_p10_f16_f32_inreg:
    case aco_opcode::v_interp_p2_f16_f32_inreg:
    case aco_opcode::v_interp_p10_rtz_f16_f32_inreg:
-   case aco_opcode::v_interp_p2_rtz_f16_f32_inreg: return false;
+   case aco_opcode::v_interp_p2_rtz_f16_f32_inreg:
+   case aco_opcode::v_wmma_f32_16x16x16_f16:
+   case aco_opcode::v_wmma_f32_16x16x16_bf16:
+   case aco_opcode::v_wmma_f16_16x16x16_f16:
+   case aco_opcode::v_wmma_bf16_16x16x16_bf16:
+   case aco_opcode::v_wmma_i32_16x16x16_iu8:
+   case aco_opcode::v_wmma_i32_16x16x16_iu4: return false;
    default: return true;
    }
 }
diff --git a/src/amd/compiler/aco_statistics.cpp b/src/amd/compiler/aco_statistics.cpp
index 9b0ee8a24ba..89303a43809 100644
--- a/src/amd/compiler/aco_statistics.cpp
+++ b/src/amd/compiler/aco_statistics.cpp
@@ -223,6 +223,11 @@ get_perf_info(const Program& program, const Instruction& instr)
                                                : perf_info{0, WAIT_USE(lds, 
1)};
       case instr_class::exp: return {0, WAIT_USE(export_gds, 1)};
       case instr_class::vmem: return {0, WAIT_USE(vmem, 1)};
+      case instr_class::wmma: {
+         /* int8 and (b)f16 have the same performance. */
+         uint8_t cost = instr.opcode == aco_opcode::v_wmma_i32_16x16x16_iu4 ? 
16 : 32;
+         return {cost, WAIT_USE(valu, cost)};
+      }
       case instr_class::barrier:
       case instr_class::waitcnt:
       case instr_class::other:
diff --git a/src/amd/compiler/aco_validate.cpp b/src/amd/compiler/aco_validate.cpp
index 4a43cdf1136..f344b48d43f 100644
--- a/src/amd/compiler/aco_validate.cpp
+++ b/src/amd/compiler/aco_validate.cpp
@@ -259,8 +259,9 @@ validate_ir(Program* program)
                   check(!vop3p.opsel_lo[i] && !vop3p.opsel_hi[i],
                         "Unexpected opsel for subdword operand", instr.get());
             }
-            check(instr->definitions[0].regClass() == v1, "VOP3P must have v1 
definition",
-                  instr.get());
+            check(instr->definitions[0].regClass() == v1 ||
+                     instr_info.classes[(int)instr->opcode] == 
instr_class::wmma,
+                  "VOP3P must have v1 definition", instr.get());
          }
 
          /* check for undefs */

Reply via email to