Module: Mesa
Branch: main
Commit: b6c2a5d48dd71b433315f81b429907a239cf309d
URL:    http://cgit.freedesktop.org/mesa/mesa/commit/?id=b6c2a5d48dd71b433315f81b429907a239cf309d

Author: Rhys Perry <pendingchao...@gmail.com>
Date:   Fri Oct 13 16:48:36 2023 +0100

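nir/loop_analyze: fix vector basis/limit/comparison

Use the actual scalar component of the induction variable's basis, of the
limit and of the loop condition instead of always assuming component 0, and
drop the early bail-out for multi-component conditions, bases and limits.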

Signed-off-by: Rhys Perry <pendingchao...@gmail.com>
Acked-by: Timothy Arceri <tarc...@itsqueeze.com>
Part-of: <https://gitlab.freedesktop.org/mesa/mesa/-/merge_requests/26225>

---

 src/compiler/nir/nir_loop_analyze.c | 31 ++++++++++++-------------------
 1 file changed, 12 insertions(+), 19 deletions(-)

diff --git a/src/compiler/nir/nir_loop_analyze.c b/src/compiler/nir/nir_loop_analyze.c
index e06a8f61392..4996f5f5325 100644
--- a/src/compiler/nir/nir_loop_analyze.c
+++ b/src/compiler/nir/nir_loop_analyze.c
@@ -911,8 +911,8 @@ get_iteration(nir_op cond_op, nir_const_value initial, nir_const_value step,
 }
 
 static int32_t
-get_iteration_empirical(nir_alu_instr *cond_alu, nir_alu_instr *incr_alu,
-                        nir_def *basis, nir_const_value initial,
+get_iteration_empirical(nir_scalar cond, nir_alu_instr *incr_alu,
+                        nir_scalar basis, nir_const_value initial,
                         bool invert_cond, unsigned execution_mode,
                         unsigned max_unroll_iterations)
 {
@@ -920,14 +920,12 @@ get_iteration_empirical(nir_alu_instr *cond_alu, nir_alu_instr *incr_alu,
    nir_const_value result;
    nir_const_value iter = initial;
 
-   const nir_scalar original = nir_get_scalar(basis, 0);
-   const nir_scalar cond = nir_get_scalar(&cond_alu->def, 0);
-   const nir_scalar incr = nir_get_scalar(&incr_alu->def, 0);
+   const nir_scalar incr = nir_get_scalar(&incr_alu->def, basis.comp);
 
    while (iter_count <= max_unroll_iterations) {
       bool success;
 
-      success = try_eval_const_alu(&result, cond, &original, &iter,
+      success = try_eval_const_alu(&result, cond, &basis, &iter,
                                    1, execution_mode);
       if (!success)
          return -1;
@@ -938,7 +936,7 @@ get_iteration_empirical(nir_alu_instr *cond_alu, nir_alu_instr *incr_alu,
 
       iter_count++;
 
-      success = try_eval_const_alu(&result, incr, &original, &iter,
+      success = try_eval_const_alu(&result, incr, &basis, &iter,
                                    1, execution_mode);
       assert(success);
 
@@ -949,17 +947,16 @@ get_iteration_empirical(nir_alu_instr *cond_alu, nir_alu_instr *incr_alu,
 }
 
 static bool
-will_break_on_first_iteration(nir_alu_instr *cond_alu, nir_def *basis,
-                              nir_def *limit_basis,
+will_break_on_first_iteration(nir_scalar cond, nir_scalar basis,
+                              nir_scalar limit_basis,
                               nir_const_value initial, nir_const_value limit,
                               bool invert_cond, unsigned execution_mode)
 {
    nir_const_value result;
 
-   const nir_scalar originals[2] = { nir_get_scalar(basis, 0), 
nir_get_scalar(limit_basis, 0) };
+   const nir_scalar originals[2] = { basis, limit_basis };
    const nir_const_value replacements[2] = { initial, limit };
 
-   const nir_scalar cond = nir_get_scalar(&cond_alu->def, 0);
    ASSERTED bool success = try_eval_const_alu(&result, cond, originals,
                                               replacements, 2, execution_mode);
 
@@ -1018,7 +1015,7 @@ test_iterations(int32_t iter_int, nir_const_value step,
 }
 
 static int
-calculate_iterations(nir_def *basis, nir_def *limit_basis,
+calculate_iterations(nir_scalar basis, nir_scalar limit_basis,
                      nir_const_value initial, nir_const_value step,
                      nir_const_value limit, nir_alu_instr *alu,
                      nir_scalar cond, nir_op alu_op, bool limit_rhs,
@@ -1041,10 +1038,6 @@ calculate_iterations(nir_def *basis, nir_def *limit_basis,
              induction_base_type);
    }
 
-   if (cond.def->num_components != 1 || basis->num_components != 1 ||
-       limit_basis->num_components != 1)
-      return -1;
-
    /* do-while loops can increment the starting value before the condition is
     * checked. e.g.
     *
@@ -1069,7 +1062,7 @@ calculate_iterations(nir_def *basis, nir_def *limit_basis,
     * however if the loop condition is false on the first iteration
     * get_iteration's assumption is broken. Handle such loops first.
     */
-   if (will_break_on_first_iteration(cond_alu, basis, limit_basis, initial,
+   if (will_break_on_first_iteration(cond, basis, limit_basis, initial,
                                      limit, invert_cond, execution_mode)) {
       return 0;
    }
@@ -1101,7 +1094,7 @@ calculate_iterations(nir_def *basis, nir_def *limit_basis,
    case nir_op_ishl:
    case nir_op_ishr:
    case nir_op_ushr:
-      return get_iteration_empirical(cond_alu, alu, basis, initial,
+      return get_iteration_empirical(cond, alu, basis, initial,
                                      invert_cond, execution_mode,
                                      max_unroll_iterations);
    default:
@@ -1356,7 +1349,7 @@ find_trip_count(loop_info_state *state, unsigned execution_mode,
       nir_const_value initial_val = nir_scalar_as_const_value(initial_s);
       nir_const_value step_val = nir_scalar_as_const_value(alu_s);
 
-      int iterations = calculate_iterations(lv->basis, limit.def,
+      int iterations = calculate_iterations(nir_get_scalar(lv->basis, 
basic_ind.comp), limit,
                                             initial_val, step_val, limit_val,
                                             
nir_instr_as_alu(nir_src_parent_instr(&lv->update_src->src)),
                                             cond,

Reply via email to