Module: Mesa
Branch: main
Commit: 9591c366669c80cfb41e8a6d95b032f37f4f25a7
URL:    http://cgit.freedesktop.org/mesa/mesa/commit/?id=9591c366669c80cfb41e8a6d95b032f37f4f25a7
Author: Rhys Perry <pendingchao...@gmail.com> Date: Tue Nov 14 20:26:44 2023 +0000 nir/loop_analyze: check min compatibility with comparison Signed-off-by: Rhys Perry <pendingchao...@gmail.com> Acked-by: Timothy Arceri <tarc...@itsqueeze.com> Part-of: <https://gitlab.freedesktop.org/mesa/mesa/-/merge_requests/26225> --- src/compiler/nir/nir_loop_analyze.c | 42 +++++++- src/compiler/nir/tests/loop_analyze_tests.cpp | 144 ++++++++++++++++++++++++++ 2 files changed, 182 insertions(+), 4 deletions(-) diff --git a/src/compiler/nir/nir_loop_analyze.c b/src/compiler/nir/nir_loop_analyze.c index 4996f5f5325..3fb21441f24 100644 --- a/src/compiler/nir/nir_loop_analyze.c +++ b/src/compiler/nir/nir_loop_analyze.c @@ -671,15 +671,49 @@ guess_loop_limit(loop_info_state *state, nir_const_value *limit_val, return false; } +static nir_op invert_comparison_if_needed(nir_op alu_op, bool invert); + +/* Returns whether "limit_op(a, b) alu_op c" is equivalent to "(a alu_op c) || (b alu_op c)". */ +static bool +is_min_compatible(nir_op limit_op, nir_op alu_op, bool limit_rhs, bool invert_cond) +{ + switch (limit_op) { + case nir_op_imin: + case nir_op_fmin: + break; + default: + return false; + } + + if (nir_op_infos[limit_op].input_types[0] != nir_op_infos[alu_op].input_types[0]) + return false; + + /* Comparisons we can split are: + * - min(a, b) < c + * - c >= min(a, b) + */ + switch (invert_comparison_if_needed(alu_op, invert_cond)) { + case nir_op_ilt: + case nir_op_flt: + return !limit_rhs; + case nir_op_ige: + case nir_op_fge: + return limit_rhs; + default: + return false; + } +} + static bool -try_find_limit_of_alu(nir_scalar limit, nir_const_value *limit_val, - nir_loop_terminator *terminator, loop_info_state *state) +try_find_limit_of_alu(nir_scalar limit, nir_const_value *limit_val, nir_op alu_op, + bool invert_cond, nir_loop_terminator *terminator, + loop_info_state *state) { if (!nir_scalar_is_alu(limit)) return false; nir_op limit_op = nir_scalar_alu_op(limit); - if 
(limit_op == nir_op_imin || limit_op == nir_op_fmin) { + if (is_min_compatible(limit_op, alu_op, !terminator->induction_rhs, invert_cond)) { for (unsigned i = 0; i < 2; i++) { nir_scalar src = nir_scalar_chase_alu_src(limit, i); if (nir_scalar_is_const(src)) { @@ -1308,7 +1342,7 @@ find_trip_count(loop_info_state *state, unsigned execution_mode, } else { trip_count_known = false; - if (!try_find_limit_of_alu(limit, &limit_val, terminator, state)) { + if (!try_find_limit_of_alu(limit, &limit_val, alu_op, invert_cond, terminator, state)) { /* Guess loop limit based on array access */ if (!guess_loop_limit(state, &limit_val, basic_ind)) { terminator->exact_trip_count_unknown = true; diff --git a/src/compiler/nir/tests/loop_analyze_tests.cpp b/src/compiler/nir/tests/loop_analyze_tests.cpp index f7c3c41d750..acd574bad54 100644 --- a/src/compiler/nir/tests/loop_analyze_tests.cpp +++ b/src/compiler/nir/tests/loop_analyze_tests.cpp @@ -286,6 +286,28 @@ COMPARE_REVERSE(ishl) INOT_COMPARE(ilt_rev) INOT_COMPARE(ine) +#define CMP_MIN(cmp, min) \ + static nir_def *nir_##cmp##_##min(nir_builder *b, nir_def *counter, nir_def *limit) \ + { \ + nir_def *unk = nir_load_vertex_id(b); \ + return nir_##cmp(b, counter, nir_##min(b, limit, unk)); \ + } + +#define CMP_MIN_REV(cmp, min) \ + static nir_def *nir_##cmp##_##min##_rev(nir_builder *b, nir_def *counter, nir_def *limit) \ + { \ + nir_def *unk = nir_load_vertex_id(b); \ + return nir_##cmp(b, nir_##min(b, limit, unk), counter); \ + } + +CMP_MIN(ige, imin) +CMP_MIN_REV(ige, imin) +CMP_MIN(ige, fmin) +CMP_MIN(uge, imin) +CMP_MIN(ilt, imin) +CMP_MIN_REV(ilt, imin) +INOT_COMPARE(ilt_imin_rev) + #define KNOWN_COUNT_TEST(_init_value, _cond_value, _incr_value, cond, incr, count) \ TEST_F(nir_loop_analyze_test, incr ## _ ## cond ## _known_count_ ## count) \ { \ @@ -320,6 +342,40 @@ INOT_COMPARE(ine) } \ } +#define INEXACT_COUNT_TEST(_init_value, _cond_value, _incr_value, cond, incr, count) \ + TEST_F(nir_loop_analyze_test, incr ## _ ## cond 
## _inexact_count_ ## count) \ + { \ + nir_loop *loop = \ + loop_builder(&b, {.init_value = _init_value, \ + .cond_value = _cond_value, \ + .incr_value = _incr_value, \ + .cond_instr = nir_ ## cond, \ + .incr_instr = nir_ ## incr}); \ + \ + nir_validate_shader(b.shader, "input"); \ + \ + nir_loop_analyze_impl(b.impl, nir_var_all, false); \ + \ + ASSERT_NE((void *)0, loop->info); \ + EXPECT_NE((void *)0, loop->info->limiting_terminator); \ + EXPECT_EQ(count, loop->info->max_trip_count); \ + EXPECT_FALSE(loop->info->exact_trip_count_known); \ + \ + EXPECT_EQ(2, loop->info->num_induction_vars); \ + ASSERT_NE((void *)0, loop->info->induction_vars); \ + \ + const nir_loop_induction_variable *const ivars = \ + loop->info->induction_vars; \ + \ + for (unsigned i = 0; i < loop->info->num_induction_vars; i++) { \ + EXPECT_NE((void *)0, ivars[i].def); \ + ASSERT_NE((void *)0, ivars[i].init_src); \ + EXPECT_TRUE(nir_src_is_const(*ivars[i].init_src)); \ + ASSERT_NE((void *)0, ivars[i].update_src); \ + EXPECT_TRUE(nir_src_is_const(ivars[i].update_src->src)); \ + } \ + } + #define UNKNOWN_COUNT_TEST(_init_value, _cond_value, _incr_value, cond, incr) \ TEST_F(nir_loop_analyze_test, incr ## _ ## cond ## _unknown_count) \ { \ @@ -567,6 +623,16 @@ KNOWN_COUNT_TEST_INVERT(0x00000000, 0x00000001, 0x00000006, ige, iadd, 5) */ KNOWN_COUNT_TEST(0x0000000a, 0x00000005, 0xffffffff, inot_ilt_rev, iadd, 5) +/* int i = 10; + * while (true) { + * if (!(imin(vertex_id, 5) < i)) + * break; + * + * i += -1; + * } + */ +UNKNOWN_COUNT_TEST(0x0000000a, 0x00000005, 0xffffffff, inot_ilt_imin_rev, iadd) + /* uint i = 0; * while (true) { * if (i != 0) @@ -1459,3 +1525,81 @@ KNOWN_COUNT_TEST_INVERT(0x0000007f, 0x00000003, 0x00000001, ilt, imul, 16) * } */ KNOWN_COUNT_TEST_INVERT(0xffff7fff, 0x0000000f, 0x34cce9b0, ige, imul, 4) + +/* int i = 0; + * while (true) { + * if (i >= imin(vertex_id, 4)) + * break; + * + * i++; + * } + */ +INEXACT_COUNT_TEST(0x00000000, 0x00000004, 0x00000001, ige_imin, iadd, 4) 
+ +/* This fmin is the wrong type to be useful. + * + * int i = 0; + * while (true) { + * if (i >= fmin(vertex_id, 4)) + * break; + * + * i++; + * } + */ +UNKNOWN_COUNT_TEST(0x00000000, 0x00000004, 0x00000001, ige_fmin, iadd) + +/* The comparison is unsigned, so this isn't safe if vertex_id is negative. + * + * uint i = 0; + * while (true) { + * if (i >= imin(vertex_id, 4)) + * break; + * + * i++; + * } + */ +UNKNOWN_COUNT_TEST(0x00000000, 0x00000004, 0x00000001, uge_imin, iadd) + +/* int i = 8; + * while (true) { + * if (4 >= i) + * break; + * + * i += -1; + * } + */ +KNOWN_COUNT_TEST(0x00000008, 0x00000004, 0xffffffff, ige_rev, iadd, 4) + +/* int i = 8; + * while (true) { + * if (i < 4) + * break; + * + * i += -1; + * } + */ +KNOWN_COUNT_TEST(0x00000008, 0x00000004, 0xffffffff, ilt, iadd, 5) + +/* This imin can increase the iteration count, not limit it. + * + * int i = 8; + * while (true) { + * if (imin(vertex_id, 4) >= i) + * break; + * + * i += -1; + * } + */ +UNKNOWN_COUNT_TEST(0x00000008, 0x00000004, 0xffffffff, ige_imin_rev, iadd) + +/* This imin can increase the iteration count, not limit it. + * + * int i = 8; + * while (true) { + * if (i < imin(vertex_id, 4)) + * break; + * + * i += -1; + * } + */ +UNKNOWN_COUNT_TEST(0x00000008, 0x00000004, 0xffffffff, ilt_imin, iadd)