Module: Mesa
Branch: main
Commit: 9591c366669c80cfb41e8a6d95b032f37f4f25a7
URL:    http://cgit.freedesktop.org/mesa/mesa/commit/?id=9591c366669c80cfb41e8a6d95b032f37f4f25a7

Author: Rhys Perry <pendingchao...@gmail.com>
Date:   Tue Nov 14 20:26:44 2023 +0000

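nir/loop_analyze: check min compatibility with comparison

try_find_limit_of_alu() previously treated any imin/fmin limit as bounded
by its constant operand. That is only sound when the min's source type
matches the comparison's, and when the (possibly inverted) exit condition
fires as soon as either operand would satisfy it, i.e. "min(a, b) < i" or
"i >= min(a, b)". In other cases the min can increase the trip count
rather than limit it.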

Signed-off-by: Rhys Perry <pendingchao...@gmail.com>
Acked-by: Timothy Arceri <tarc...@itsqueeze.com>
Part-of: <https://gitlab.freedesktop.org/mesa/mesa/-/merge_requests/26225>

---

 src/compiler/nir/nir_loop_analyze.c           |  42 +++++++-
 src/compiler/nir/tests/loop_analyze_tests.cpp | 144 ++++++++++++++++++++++++++
 2 files changed, 182 insertions(+), 4 deletions(-)

diff --git a/src/compiler/nir/nir_loop_analyze.c b/src/compiler/nir/nir_loop_analyze.c
index 4996f5f5325..3fb21441f24 100644
--- a/src/compiler/nir/nir_loop_analyze.c
+++ b/src/compiler/nir/nir_loop_analyze.c
@@ -671,15 +671,49 @@ guess_loop_limit(loop_info_state *state, nir_const_value *limit_val,
    return false;
 }
 
+static nir_op invert_comparison_if_needed(nir_op alu_op, bool invert);
+
+/* Returns whether "limit_op(a, b) alu_op c" is equivalent to "(a alu_op c) || (b alu_op c)". */
+static bool
+is_min_compatible(nir_op limit_op, nir_op alu_op, bool limit_rhs, bool invert_cond)
+{
+   switch (limit_op) {
+   case nir_op_imin:
+   case nir_op_fmin:
+      break;
+   default:
+      return false;
+   }
+
+   if (nir_op_infos[limit_op].input_types[0] != nir_op_infos[alu_op].input_types[0])
+      return false;
+
+   /* Comparisons we can split are:
+    * - min(a, b) < c
+    * - c >= min(a, b)
+    */
+   switch (invert_comparison_if_needed(alu_op, invert_cond)) {
+   case nir_op_ilt:
+   case nir_op_flt:
+      return !limit_rhs;
+   case nir_op_ige:
+   case nir_op_fge:
+      return limit_rhs;
+   default:
+      return false;
+   }
+}
+
 static bool
-try_find_limit_of_alu(nir_scalar limit, nir_const_value *limit_val,
-                      nir_loop_terminator *terminator, loop_info_state *state)
+try_find_limit_of_alu(nir_scalar limit, nir_const_value *limit_val, nir_op alu_op,
+                      bool invert_cond, nir_loop_terminator *terminator,
+                      loop_info_state *state)
 {
    if (!nir_scalar_is_alu(limit))
       return false;
 
    nir_op limit_op = nir_scalar_alu_op(limit);
-   if (limit_op == nir_op_imin || limit_op == nir_op_fmin) {
+   if (is_min_compatible(limit_op, alu_op, !terminator->induction_rhs, invert_cond)) {
       for (unsigned i = 0; i < 2; i++) {
          nir_scalar src = nir_scalar_chase_alu_src(limit, i);
          if (nir_scalar_is_const(src)) {
@@ -1308,7 +1342,7 @@ find_trip_count(loop_info_state *state, unsigned execution_mode,
       } else {
          trip_count_known = false;
 
-         if (!try_find_limit_of_alu(limit, &limit_val, terminator, state)) {
+         if (!try_find_limit_of_alu(limit, &limit_val, alu_op, invert_cond, terminator, state)) {
             /* Guess loop limit based on array access */
             if (!guess_loop_limit(state, &limit_val, basic_ind)) {
                terminator->exact_trip_count_unknown = true;
diff --git a/src/compiler/nir/tests/loop_analyze_tests.cpp b/src/compiler/nir/tests/loop_analyze_tests.cpp
index f7c3c41d750..acd574bad54 100644
--- a/src/compiler/nir/tests/loop_analyze_tests.cpp
+++ b/src/compiler/nir/tests/loop_analyze_tests.cpp
@@ -286,6 +286,28 @@ COMPARE_REVERSE(ishl)
 INOT_COMPARE(ilt_rev)
 INOT_COMPARE(ine)
 
+#define CMP_MIN(cmp, min)                                               \
+   static nir_def *nir_##cmp##_##min(nir_builder *b, nir_def *counter, nir_def *limit) \
+   {                                                                    \
+      nir_def *unk = nir_load_vertex_id(b);                             \
+      return nir_##cmp(b, counter, nir_##min(b, limit, unk));           \
+   }
+
+#define CMP_MIN_REV(cmp, min)                                           \
+   static nir_def *nir_##cmp##_##min##_rev(nir_builder *b, nir_def *counter, nir_def *limit) \
+   {                                                                    \
+      nir_def *unk = nir_load_vertex_id(b);                             \
+      return nir_##cmp(b, nir_##min(b, limit, unk), counter);           \
+   }
+
+CMP_MIN(ige, imin)
+CMP_MIN_REV(ige, imin)
+CMP_MIN(ige, fmin)
+CMP_MIN(uge, imin)
+CMP_MIN(ilt, imin)
+CMP_MIN_REV(ilt, imin)
+INOT_COMPARE(ilt_imin_rev)
+
 #define KNOWN_COUNT_TEST(_init_value, _cond_value, _incr_value, cond, incr, count) \
    TEST_F(nir_loop_analyze_test, incr ## _ ## cond ## _known_count_ ## count)    \
    {                                                                    \
@@ -320,6 +342,40 @@ INOT_COMPARE(ine)
       }                                                                 \
    }
 
+#define INEXACT_COUNT_TEST(_init_value, _cond_value, _incr_value, cond, incr, count) \
+   TEST_F(nir_loop_analyze_test, incr ## _ ## cond ## _inexact_count_ ## count)    \
+   {                                                                    \
+      nir_loop *loop =                                                  \
+         loop_builder(&b, {.init_value = _init_value,                   \
+                           .cond_value = _cond_value,                   \
+                           .incr_value = _incr_value,                   \
+                           .cond_instr = nir_ ## cond,                  \
+                           .incr_instr = nir_ ## incr});                \
+                                                                        \
+      nir_validate_shader(b.shader, "input");                           \
+                                                                        \
+      nir_loop_analyze_impl(b.impl, nir_var_all, false);                \
+                                                                        \
+      ASSERT_NE((void *)0, loop->info);                                 \
+      EXPECT_NE((void *)0, loop->info->limiting_terminator);            \
+      EXPECT_EQ(count, loop->info->max_trip_count);                     \
+      EXPECT_FALSE(loop->info->exact_trip_count_known);                 \
+                                                                        \
+      EXPECT_EQ(2, loop->info->num_induction_vars);                     \
+      ASSERT_NE((void *)0, loop->info->induction_vars);                 \
+                                                                        \
+      const nir_loop_induction_variable *const ivars =                  \
+         loop->info->induction_vars;                                    \
+                                                                        \
+      for (unsigned i = 0; i < loop->info->num_induction_vars; i++) {   \
+         EXPECT_NE((void *)0, ivars[i].def);                            \
+         ASSERT_NE((void *)0, ivars[i].init_src);                       \
+         EXPECT_TRUE(nir_src_is_const(*ivars[i].init_src));             \
+         ASSERT_NE((void *)0, ivars[i].update_src);                     \
+         EXPECT_TRUE(nir_src_is_const(ivars[i].update_src->src));       \
+      }                                                                 \
+   }
+
 #define UNKNOWN_COUNT_TEST(_init_value, _cond_value, _incr_value, cond, incr) \
    TEST_F(nir_loop_analyze_test, incr ## _ ## cond ## _unknown_count)   \
    {                                                                    \
@@ -567,6 +623,16 @@ KNOWN_COUNT_TEST_INVERT(0x00000000, 0x00000001, 0x00000006, ige, iadd, 5)
   */
 KNOWN_COUNT_TEST(0x0000000a, 0x00000005, 0xffffffff, inot_ilt_rev, iadd, 5)
 
+/*    int i = 10;
+ *    while (true) {
+ *       if (!(imin(vertex_id, 5) < i))
+ *          break;
+ *
+ *       i += -1;
+ *    }
+ */
+UNKNOWN_COUNT_TEST(0x0000000a, 0x00000005, 0xffffffff, inot_ilt_imin_rev, iadd)
+
 /*    uint i = 0;
  *    while (true) {
  *       if (i != 0)
@@ -1459,3 +1525,81 @@ KNOWN_COUNT_TEST_INVERT(0x0000007f, 0x00000003, 0x00000001, ilt, imul, 16)
  *    }
  */
 KNOWN_COUNT_TEST_INVERT(0xffff7fff, 0x0000000f, 0x34cce9b0, ige, imul, 4)
+
+/*    int i = 0;
+ *    while (true) {
+ *       if (i >= imin(vertex_id, 4))
+ *          break;
+ *
+ *       i++;
+ *    }
+ */
+INEXACT_COUNT_TEST(0x00000000, 0x00000004, 0x00000001, ige_imin, iadd, 4)
+
+/* This fmin is the wrong type to be useful.
+ *
+ *    int i = 0;
+ *    while (true) {
+ *       if (i >= fmin(vertex_id, 4))
+ *          break;
+ *
+ *       i++;
+ *    }
+ */
+UNKNOWN_COUNT_TEST(0x00000000, 0x00000004, 0x00000001, ige_fmin, iadd)
+
+/* The comparison is unsigned, so this isn't safe if vertex_id is negative.
+ *
+ *    uint i = 0;
+ *    while (true) {
+ *       if (i >= imin(vertex_id, 4))
+ *          break;
+ *
+ *       i++;
+ *    }
+ */
+UNKNOWN_COUNT_TEST(0x00000000, 0x00000004, 0x00000001, uge_imin, iadd)
+
+/*    int i = 8;
+ *    while (true) {
+ *       if (4 >= i)
+ *          break;
+ *
+ *       i += -1;
+ *    }
+ */
+KNOWN_COUNT_TEST(0x00000008, 0x00000004, 0xffffffff, ige_rev, iadd, 4)
+
+/*    int i = 8;
+ *    while (true) {
+ *       if (i < 4)
+ *          break;
+ *
+ *       i += -1;
+ *    }
+ */
+KNOWN_COUNT_TEST(0x00000008, 0x00000004, 0xffffffff, ilt, iadd, 5)
+
+/* This imin can increase the iteration count, not limit it.
+ *
+ *    int i = 8;
+ *    while (true) {
+ *       if (imin(vertex_id, 4) >= i)
+ *          break;
+ *
+ *       i += -1;
+ *    }
+ */
+UNKNOWN_COUNT_TEST(0x00000008, 0x00000004, 0xffffffff, ige_imin_rev, iadd)
+
+/* This imin can increase the iteration count, not limit it.
+ *
+ *    int i = 8;
+ *    while (true) {
+ *       if (i < imin(vertex_id, 4))
+ *          break;
+ *
+ *       i += -1;
+ *    }
+ */
+UNKNOWN_COUNT_TEST(0x00000008, 0x00000004, 0xffffffff, ilt_imin, iadd)

Reply via email to