Module: Mesa
Branch: main
Commit: f2e41eda9ee53897805ab02a1a7915258d3f3768
URL:    http://cgit.freedesktop.org/mesa/mesa/commit/?id=f2e41eda9ee53897805ab02a1a7915258d3f3768

Author: Timur Kristóf <[email protected]>
Date:   Thu Sep  9 08:38:41 2021 +0200

aco: Add ability to optimize v_lshl + v_sub into v_mad_i32_i24.

Also change combine_add_lshl to use check_vop3_operands instead
of performing its own operand checks.

Signed-off-by: Timur Kristóf <[email protected]>
Reviewed-by: Rhys Perry <[email protected]>
Part-of: <https://gitlab.freedesktop.org/mesa/mesa/-/merge_requests/12786>

---

 src/amd/compiler/aco_optimizer.cpp        | 58 ++++++++++++++++++++++---------
 src/amd/compiler/tests/test_optimizer.cpp |  3 +-
 2 files changed, 42 insertions(+), 19 deletions(-)

diff --git a/src/amd/compiler/aco_optimizer.cpp b/src/amd/compiler/aco_optimizer.cpp
index dd04b3584d6..7b22d76b64a 100644
--- a/src/amd/compiler/aco_optimizer.cpp
+++ b/src/amd/compiler/aco_optimizer.cpp
@@ -2937,14 +2937,27 @@ combine_and_subbrev(opt_ctx& ctx, aco_ptr<Instruction>& instr)
 }
 
 /* v_add_co(c, s_lshl(a, b)) -> v_mad_u32_u24(a, 1<<b, c)
- * v_add_co(c, v_lshlrev(a, b)) -> v_mad_u32_u24(b, 1<<a, c) */
+ * v_add_co(c, v_lshlrev(a, b)) -> v_mad_u32_u24(b, 1<<a, c)
+ * v_sub(c, s_lshl(a, b)) -> v_mad_i32_i24(a, -(1<<b), c)
+ * v_sub(c, v_lshlrev(a, b)) -> v_mad_i32_i24(b, -(1<<a), c)
+ */
 bool
-combine_add_lshl(opt_ctx& ctx, aco_ptr<Instruction>& instr)
+combine_add_lshl(opt_ctx& ctx, aco_ptr<Instruction>& instr, bool is_sub)
 {
    if (instr->usesModifiers())
       return false;
 
-   for (unsigned i = 0; i < 2; i++) {
+   /* Substractions: start at operand 1 to avoid mixup such as
+    * turning v_sub(v_lshlrev(a, b), c) into v_mad_i32_i24(b, -(1<<a), c)
+    */
+   unsigned start_op_idx = is_sub ? 1 : 0;
+
+   /* Don't allow 24-bit operands on subtraction because
+    * v_mad_i32_i24 applies a sign extension.
+    */
+   bool allow_24bit = !is_sub;
+
+   for (unsigned i = start_op_idx; i < 2; i++) {
       Instruction* op_instr = follow_operand(ctx, instr->operands[i]);
       if (!op_instr)
          continue;
@@ -2953,25 +2966,32 @@ combine_add_lshl(opt_ctx& ctx, aco_ptr<Instruction>& instr)
           op_instr->opcode != aco_opcode::v_lshlrev_b32)
          continue;
 
-      if (op_instr->opcode == aco_opcode::v_lshlrev_b32 && 
op_instr->operands[1].isTemp() &&
-          op_instr->operands[1].getTemp().type() == RegType::sgpr && 
instr->operands[!i].isTemp() &&
-          instr->operands[!i].getTemp().type() == RegType::sgpr)
-         return false;
-
       int shift_op_idx = op_instr->opcode == aco_opcode::s_lshl_b32 ? 1 : 0;
+
       if (op_instr->operands[shift_op_idx].isConstant() &&
-          op_instr->operands[shift_op_idx].constantValue() <= 6 && /* no 
literals */
-          (op_instr->operands[!shift_op_idx].is24bit() ||
+          ((allow_24bit && op_instr->operands[!shift_op_idx].is24bit()) ||
            op_instr->operands[!shift_op_idx].is16bit())) {
-         uint32_t multiplier = 1 << 
op_instr->operands[shift_op_idx].constantValue();
+         uint32_t multiplier = 1 << 
(op_instr->operands[shift_op_idx].constantValue() % 32u);
+         if (is_sub)
+            multiplier = -multiplier;
+         if (is_sub ? (multiplier < 0xff800000) : (multiplier > 0xffffff))
+            continue;
+
+         Operand ops[3] = {
+            op_instr->operands[!shift_op_idx],
+            Operand::c32(multiplier),
+            instr->operands[!i],
+         };
+         if (!check_vop3_operands(ctx, 3, ops))
+            return false;
 
          ctx.uses[instr->operands[i].tempId()]--;
 
+         aco_opcode mad_op = is_sub ? aco_opcode::v_mad_i32_i24 : 
aco_opcode::v_mad_u32_u24;
          aco_ptr<VOP3_instruction> new_instr{
-            create_instruction<VOP3_instruction>(aco_opcode::v_mad_u32_u24, 
Format::VOP3, 3, 1)};
-         new_instr->operands[0] = op_instr->operands[!shift_op_idx];
-         new_instr->operands[1] = Operand::c32(multiplier);
-         new_instr->operands[2] = instr->operands[!i];
+            create_instruction<VOP3_instruction>(mad_op, Format::VOP3, 3, 1)};
+         for (unsigned op_idx = 0; op_idx < 3; ++op_idx)
+            new_instr->operands[op_idx] = ops[op_idx];
          new_instr->definitions[0] = instr->definitions[0];
          instr = std::move(new_instr);
          ctx.info[instr->definitions[0].tempId()].label = 0;
@@ -3432,11 +3452,15 @@ combine_instruction(opt_ctx& ctx, aco_ptr<Instruction>& instr)
       } else if (!carry_out && combine_add_bcnt(ctx, instr)) {
       } else if (!carry_out && combine_three_valu_op(ctx, instr, 
aco_opcode::v_mul_u32_u24,
                                                      
aco_opcode::v_mad_u32_u24, "120", 1 | 2)) {
-      } else if (!carry_out && combine_add_lshl(ctx, instr)) {
+      } else if (!carry_out && combine_add_lshl(ctx, instr, false)) {
       }
    } else if (instr->opcode == aco_opcode::v_sub_u32 || instr->opcode == 
aco_opcode::v_sub_co_u32 ||
               instr->opcode == aco_opcode::v_sub_co_u32_e64) {
-      combine_add_sub_b2i(ctx, instr, aco_opcode::v_subbrev_co_u32, 2);
+      bool carry_out = instr->opcode != aco_opcode::v_sub_u32 &&
+                       ctx.uses[instr->definitions[1].tempId()] > 0;
+      if (combine_add_sub_b2i(ctx, instr, aco_opcode::v_subbrev_co_u32, 2)) {
+      } else if (!carry_out && combine_add_lshl(ctx, instr, true)) {
+      }
    } else if (instr->opcode == aco_opcode::v_subrev_u32 ||
               instr->opcode == aco_opcode::v_subrev_co_u32 ||
               instr->opcode == aco_opcode::v_subrev_co_u32_e64) {
diff --git a/src/amd/compiler/tests/test_optimizer.cpp b/src/amd/compiler/tests/test_optimizer.cpp
index 9609fea4f2b..31a229f99e9 100644
--- a/src/amd/compiler/tests/test_optimizer.cpp
+++ b/src/amd/compiler/tests/test_optimizer.cpp
@@ -802,8 +802,7 @@ BEGIN_TEST(optimize.add_lshlrev)
       lshl = bld.vop2(aco_opcode::v_lshlrev_b32, bld.def(v1), 
Operand::c32(4u), a_16bit);
       writeout(4, bld.vadd32(bld.def(v1), lshl, Operand(inputs[1])));
 
-      //~gfx8! v1: %lshl5 = v_lshlrev_b32 4, (is24bit)%c
-      //~gfx8! v1: %res5, s2: %_ = v_add_co_u32 %c, %lshl5
+      //~gfx8! v1: %res5 = v_mad_u32_u24 (is24bit)%c, 16, %c
       //~gfx(9|10)! v1: %res5 = v_lshl_add_u32 (is24bit)%c, 4, %c
       //! p_unit_test 5, %res5
       Operand c_24bit = Operand(inputs[2]);

Reply via email to