Module: Mesa Branch: main Commit: 5e7c828c0e4c81a36ecf52766183f228ae0c042c URL: http://cgit.freedesktop.org/mesa/mesa/commit/?id=5e7c828c0e4c81a36ecf52766183f228ae0c042c
Author: Bas Nieuwenhuizen <[email protected]> Date: Sat Jul 15 19:49:49 2023 +0200 aco: Add WMMA instructions. Part-of: <https://gitlab.freedesktop.org/mesa/mesa/-/merge_requests/24683> --- src/amd/compiler/aco_insert_waitcnt.cpp | 3 +- src/amd/compiler/aco_instruction_selection.cpp | 43 ++++++++++++++++++++++ .../compiler/aco_instruction_selection_setup.cpp | 3 +- src/amd/compiler/aco_ir.h | 1 + src/amd/compiler/aco_opcodes.py | 7 ++++ src/amd/compiler/aco_optimizer.cpp | 16 +++++++- src/amd/compiler/aco_statistics.cpp | 5 +++ src/amd/compiler/aco_validate.cpp | 5 ++- 8 files changed, 77 insertions(+), 6 deletions(-) diff --git a/src/amd/compiler/aco_insert_waitcnt.cpp b/src/amd/compiler/aco_insert_waitcnt.cpp index 6e292a9bb72..56b4613847a 100644 --- a/src/amd/compiler/aco_insert_waitcnt.cpp +++ b/src/amd/compiler/aco_insert_waitcnt.cpp @@ -875,7 +875,8 @@ gen_alu(Instruction* instr, wait_ctx& ctx) for (const Definition& def : instr->definitions) insert_wait_entry(ctx, def, event, 0, cycle_info.latency); } - update_alu(ctx, is_valu, is_trans, clear, cycle_info.issue_cycles); + update_alu(ctx, is_valu && instr_info.classes[(int)instr->opcode] != instr_class::wmma, is_trans, + clear, cycle_info.issue_cycles); } void diff --git a/src/amd/compiler/aco_instruction_selection.cpp b/src/amd/compiler/aco_instruction_selection.cpp index 357d566e07e..8dd18f72043 100644 --- a/src/amd/compiler/aco_instruction_selection.cpp +++ b/src/amd/compiler/aco_instruction_selection.cpp @@ -8087,6 +8087,48 @@ create_fs_dual_src_export_gfx11(isel_context* ctx, const struct aco_export_mrt* ctx->program->has_color_exports = true; } +static void +visit_cmat_muladd(isel_context* ctx, nir_intrinsic_instr* instr) +{ + aco_opcode opcode = aco_opcode::num_opcodes; + unsigned signed_mask = 0; + bool clamp = false; + + switch (instr->src[0].ssa->bit_size) { + case 16: + switch (instr->def.bit_size) { + case 32: opcode = aco_opcode::v_wmma_f32_16x16x16_f16; break; + case 16: opcode = aco_opcode::v_wmma_f16_16x16x16_f16; break; + } + break; + case 8: + opcode = aco_opcode::v_wmma_i32_16x16x16_iu8; + signed_mask = nir_intrinsic_cmat_signed_mask(instr); + clamp = nir_intrinsic_saturate(instr); + break; + } + + if (opcode == aco_opcode::num_opcodes) + unreachable("visit_cmat_muladd: invalid bit size combination"); + + Builder bld(ctx->program, ctx->block); + + Temp dst = get_ssa_temp(ctx, &instr->def); + Operand A(as_vgpr(ctx, get_ssa_temp(ctx, instr->src[0].ssa))); + Operand B(as_vgpr(ctx, get_ssa_temp(ctx, instr->src[1].ssa))); + Operand C(as_vgpr(ctx, get_ssa_temp(ctx, instr->src[2].ssa))); + + A.setLateKill(true); + B.setLateKill(true); + + VALU_instruction& vop3p = bld.vop3p(opcode, Definition(dst), A, B, C, 0, 0)->valu(); + vop3p.neg_lo[0] = (signed_mask & 0x1) != 0; + vop3p.neg_lo[1] = (signed_mask & 0x2) != 0; + vop3p.clamp = clamp; + + emit_split_vector(ctx, dst, instr->def.num_components); +} + void visit_intrinsic(isel_context* ctx, nir_intrinsic_instr* instr) { @@ -9174,6 +9216,7 @@ visit_intrinsic(isel_context* ctx, nir_intrinsic_instr* instr) bld.pseudo(aco_opcode::p_pops_gfx9_ordered_section_done); break; } + case nir_intrinsic_cmat_muladd_amd: visit_cmat_muladd(ctx, instr); break; default: isel_err(&instr->instr, "Unimplemented intrinsic instr"); abort(); diff --git a/src/amd/compiler/aco_instruction_selection_setup.cpp b/src/amd/compiler/aco_instruction_selection_setup.cpp index e7e56549205..4e27cdc8bcf 100644 --- a/src/amd/compiler/aco_instruction_selection_setup.cpp +++ b/src/amd/compiler/aco_instruction_selection_setup.cpp @@ -535,7 +535,8 @@ init_context(isel_context* ctx, nir_shader* shader) case nir_intrinsic_bvh64_intersect_ray_amd: case nir_intrinsic_load_vector_arg_amd: case nir_intrinsic_load_rt_dynamic_callable_stack_base_amd: - case nir_intrinsic_ordered_xfb_counter_add_amd: type = RegType::vgpr; break; + case nir_intrinsic_ordered_xfb_counter_add_amd: + case nir_intrinsic_cmat_muladd_amd: type = RegType::vgpr; break; case nir_intrinsic_load_shared: case nir_intrinsic_load_shared2_amd: /* When the result of these loads is only used by cross-lane instructions, diff --git a/src/amd/compiler/aco_ir.h b/src/amd/compiler/aco_ir.h index d521c4d929a..ca9c2efac5b 100644 --- a/src/amd/compiler/aco_ir.h +++ b/src/amd/compiler/aco_ir.h @@ -133,6 +133,7 @@ enum class instr_class : uint8_t { vmem = 17, waitcnt = 18, other = 19, + wmma = 20, count, }; diff --git a/src/amd/compiler/aco_opcodes.py b/src/amd/compiler/aco_opcodes.py index b5b22926af9..71bc53a06c9 100644 --- a/src/amd/compiler/aco_opcodes.py +++ b/src/amd/compiler/aco_opcodes.py @@ -48,6 +48,7 @@ class InstrClass(Enum): VMem = 17 Waitcnt = 18 Other = 19 + WMMA = 20 class Format(Enum): PSEUDO = 0 @@ -1051,6 +1052,12 @@ opcode("v_dot8_i32_iu4", -1, -1, -1, 0x18, Format.VOP3P, InstrClass.Valu32) opcode("v_dot8_u32_u4", -1, 0x2b, 0x19, 0x19, Format.VOP3P, InstrClass.Valu32) opcode("v_dot2_f32_f16", -1, 0x23, 0x13, 0x13, Format.VOP3P, InstrClass.Valu32) opcode("v_dot2_f32_bf16", -1, -1, -1, 0x1a, Format.VOP3P, InstrClass.Valu32) +opcode("v_wmma_f32_16x16x16_f16", -1, -1, -1, 0x40, Format.VOP3P, InstrClass.WMMA, False, False) +opcode("v_wmma_f32_16x16x16_bf16", -1, -1, -1, 0x41, Format.VOP3P, InstrClass.WMMA, False, False) +opcode("v_wmma_f16_16x16x16_f16", -1, -1, -1, 0x42, Format.VOP3P, InstrClass.WMMA, False, False) +opcode("v_wmma_bf16_16x16x16_bf16", -1, -1, -1, 0x43, Format.VOP3P, InstrClass.WMMA, False, False) +opcode("v_wmma_i32_16x16x16_iu8", -1, -1, -1, 0x44, Format.VOP3P, InstrClass.WMMA, False, False) +opcode("v_wmma_i32_16x16x16_iu4", -1, -1, -1, 0x45, Format.VOP3P, InstrClass.WMMA, False, False) # VINTRP (GFX6 - GFX10.3) instructions: diff --git a/src/amd/compiler/aco_optimizer.cpp b/src/amd/compiler/aco_optimizer.cpp index 311b8ad4d04..9fdbffc7994 100644 --- a/src/amd/compiler/aco_optimizer.cpp +++ b/src/amd/compiler/aco_optimizer.cpp @@ -643,7 +643,13 @@ can_apply_sgprs(opt_ctx& ctx, aco_ptr<Instruction>& instr) instr->opcode != aco_opcode::v_interp_p10_f16_f32_inreg && instr->opcode != aco_opcode::v_interp_p2_f16_f32_inreg && instr->opcode != aco_opcode::v_interp_p10_rtz_f16_f32_inreg && - instr->opcode != aco_opcode::v_interp_p2_rtz_f16_f32_inreg; + instr->opcode != aco_opcode::v_interp_p2_rtz_f16_f32_inreg && + instr->opcode != aco_opcode::v_wmma_f32_16x16x16_f16 && + instr->opcode != aco_opcode::v_wmma_f32_16x16x16_bf16 && + instr->opcode != aco_opcode::v_wmma_f16_16x16x16_f16 && + instr->opcode != aco_opcode::v_wmma_bf16_16x16x16_bf16 && + instr->opcode != aco_opcode::v_wmma_i32_16x16x16_iu8 && + instr->opcode != aco_opcode::v_wmma_i32_16x16x16_iu4; } bool @@ -697,7 +703,13 @@ alu_can_accept_constant(const aco_ptr<Instruction>& instr, unsigned operand) case aco_opcode::v_interp_p10_f16_f32_inreg: case aco_opcode::v_interp_p2_f16_f32_inreg: case aco_opcode::v_interp_p10_rtz_f16_f32_inreg: - case aco_opcode::v_interp_p2_rtz_f16_f32_inreg: return false; + case aco_opcode::v_interp_p2_rtz_f16_f32_inreg: + case aco_opcode::v_wmma_f32_16x16x16_f16: + case aco_opcode::v_wmma_f32_16x16x16_bf16: + case aco_opcode::v_wmma_f16_16x16x16_f16: + case aco_opcode::v_wmma_bf16_16x16x16_bf16: + case aco_opcode::v_wmma_i32_16x16x16_iu8: + case aco_opcode::v_wmma_i32_16x16x16_iu4: return false; default: return true; } } diff --git a/src/amd/compiler/aco_statistics.cpp b/src/amd/compiler/aco_statistics.cpp index 9b0ee8a24ba..89303a43809 100644 --- a/src/amd/compiler/aco_statistics.cpp +++ b/src/amd/compiler/aco_statistics.cpp @@ -223,6 +223,11 @@ get_perf_info(const Program& program, const Instruction& instr) : perf_info{0, WAIT_USE(lds, 1)}; case instr_class::exp: return {0, WAIT_USE(export_gds, 1)}; case instr_class::vmem: return {0, WAIT_USE(vmem, 1)}; + case instr_class::wmma: { + /* int8 and (b)f16 have the same performance. */ + uint8_t cost = instr.opcode == aco_opcode::v_wmma_i32_16x16x16_iu4 ? 16 : 32; + return {cost, WAIT_USE(valu, cost)}; + } case instr_class::barrier: case instr_class::waitcnt: case instr_class::other: diff --git a/src/amd/compiler/aco_validate.cpp b/src/amd/compiler/aco_validate.cpp index 4a43cdf1136..f344b48d43f 100644 --- a/src/amd/compiler/aco_validate.cpp +++ b/src/amd/compiler/aco_validate.cpp @@ -259,8 +259,9 @@ validate_ir(Program* program) check(!vop3p.opsel_lo[i] && !vop3p.opsel_hi[i], "Unexpected opsel for subdword operand", instr.get()); } - check(instr->definitions[0].regClass() == v1, "VOP3P must have v1 definition", - instr.get()); + check(instr->definitions[0].regClass() == v1 || + instr_info.classes[(int)instr->opcode] == instr_class::wmma, + "VOP3P must have v1 definition", instr.get()); } /* check for undefs */
