Module: Mesa
Branch: main
Commit: 79c8740c6e758bef29d43fa352d7b5f4668f78a8
URL:    http://cgit.freedesktop.org/mesa/mesa/commit/?id=79c8740c6e758bef29d43fa352d7b5f4668f78a8

Author: Rhys Perry <[email protected]>
Date:   Wed Feb 23 11:33:16 2022 +0000

aco: fix fp16 opcode definitions

The v_fma_mix optimizations assume v_cvt_f16_f32 and v_mul_f16 use a v2b
definition.

Signed-off-by: Rhys Perry <[email protected]>
Reviewed-by: Daniel Schürmann <[email protected]>
Part-of: <https://gitlab.freedesktop.org/mesa/mesa/-/merge_requests/14769>

---

 src/amd/compiler/aco_instruction_selection.cpp | 7 ++++---
 1 file changed, 4 insertions(+), 3 deletions(-)

diff --git a/src/amd/compiler/aco_instruction_selection.cpp b/src/amd/compiler/aco_instruction_selection.cpp
index ea4518c2943..ebcb0a63410 100644
--- a/src/amd/compiler/aco_instruction_selection.cpp
+++ b/src/amd/compiler/aco_instruction_selection.cpp
@@ -2520,7 +2520,7 @@ visit_alu_instr(isel_context* ctx, nir_alu_instr* instr)
       aco_ptr<Instruction> norm;
       if (dst.regClass() == v2b) {
          Temp half_pi = bld.copy(bld.def(s1), Operand::c32(0x3118u));
-         Temp tmp = bld.vop2(aco_opcode::v_mul_f16, bld.def(v1), half_pi, src);
+         Temp tmp = bld.vop2(aco_opcode::v_mul_f16, bld.def(v2b), half_pi, src);
          aco_opcode opcode =
             instr->op == nir_op_fsin ? aco_opcode::v_sin_f16 : aco_opcode::v_cos_f16;
          bld.vop1(opcode, Definition(dst), tmp);
@@ -3334,7 +3334,7 @@ visit_alu_instr(isel_context* ctx, nir_alu_instr* instr)
    }
    case nir_op_fquantize2f16: {
       Temp src = get_alu_src(ctx, instr->src[0]);
-      Temp f16 = bld.vop1(aco_opcode::v_cvt_f16_f32, bld.def(v1), src);
+      Temp f16 = bld.vop1(aco_opcode::v_cvt_f16_f32, bld.def(v2b), src);
       Temp f32, cmp_res;
 
       if (ctx->program->chip_class >= GFX8) {
@@ -7642,7 +7642,8 @@ emit_addition_uniform_reduce(isel_context* ctx, nir_op op, Definition dst, nir_s
 
    if (op == nir_op_fadd) {
       src_tmp = as_vgpr(ctx, src_tmp);
-      Temp tmp = dst.regClass() == s1 ? bld.tmp(src_tmp.regClass()) : dst.getTemp();
+      Temp tmp = dst.regClass() == s1 ? bld.tmp(RegClass::get(RegType::vgpr, src.ssa->bit_size / 8))
+                                      : dst.getTemp();
 
       if (src.ssa->bit_size == 16) {
          count = bld.vop1(aco_opcode::v_cvt_f16_u16, bld.def(v2b), count);

Reply via email to