Module: Mesa Branch: main Commit: e86ab8173b573e0754d0fa85f818dca0d85c4526 URL: http://cgit.freedesktop.org/mesa/mesa/commit/?id=e86ab8173b573e0754d0fa85f818dca0d85c4526
Author: Rhys Perry <pendingchao...@gmail.com> Date: Fri Nov 17 12:24:14 2023 +0000 nir/algebraic: optimize vkd3d-proton's MSAD Signed-off-by: Rhys Perry <pendingchao...@gmail.com> Reviewed-by: Georg Lehmann <dadschoo...@gmail.com> Part-of: <https://gitlab.freedesktop.org/mesa/mesa/-/merge_requests/26907> --- src/compiler/nir/nir_opt_algebraic.py | 19 ++++++++++++++ src/compiler/nir/tests/algebraic_tests.cpp | 40 ++++++++++++++++++++++++++++++ src/compiler/nir/tests/nir_test.h | 2 +- 3 files changed, 60 insertions(+), 1 deletion(-)

Reviewer notes (placed after the "---" line, outside the commit message and
outside every hunk; the quoted patch content below is unchanged):
* nir_opt_algebraic.py, vkd3d_proton_msad(): builds the terms for byte i = 0..3.
  - Each term is bcsel(ieq(ref, 0), 0, iabs(iadd(ref, ineg(src)))), where
    ref = extract_u8(a@32, i) and src = extract_u8(b@32, i).
  - The four terms are summed as a left-nested iadd chain. Only the outermost
    iadd is tagged "(many-comm-expr)".
  - The first rule rewrites that sum to msad_4x8(a, b, 0) when
    options->has_msad is set.
  - The second rule folds an iadd of c into the third msad_4x8 operand.
  - `pattern == None` would read more idiomatically as `pattern is None`.
    The behavior is the same for a None or tuple value.
* algebraic_tests.cpp, TEST_F(nir_opt_algebraic_test, msad): `count` is only
  checked inside the loop (ASSERT_EQ(count, 0) before count++).
  - As a result, the test also passes if no ALU instruction remains in the
    start block.
  - A follow-up could add `ASSERT_EQ(count, 1u);` after the loop, so that
    exactly one msad_4x8 is required.
  - ASSERT_EQ(nir_instr_as_alu(instr)->op, nir_op_msad_4x8) would report the
    offending opcode on failure, unlike ASSERT_TRUE(... == ...).
* nir_test.h: `options` changes from a function-local static const to a
  non-const member initialized to {}.
  - The msad test writes its fields after the fixture constructor has already
    passed &options to nir_builder_init_simple_shader().
  - This presumably relies on the shader keeping that pointer rather than a
    copy. TODO: confirm.

diff --git a/src/compiler/nir/nir_opt_algebraic.py b/src/compiler/nir/nir_opt_algebraic.py index 88293c995c0..55ae777a6f1 100644 --- a/src/compiler/nir/nir_opt_algebraic.py +++ b/src/compiler/nir/nir_opt_algebraic.py @@ -2622,6 +2622,25 @@ optimizations += [ (vkd3d_proton_packed_f2f16_rtz_lo(('fneg', 'x'), ('fabs', 'x')), ('pack_half_2x16_rtz_split', ('fneg', 'x'), 0)), ] +def vkd3d_proton_msad(): + pattern = None + for i in range(4): + ref = ('extract_u8', 'a@32', i) + src = ('extract_u8', 'b@32', i) + sad = ('iabs', ('iadd', ref, ('ineg', src))) + msad = ('bcsel', ('ieq', ref, 0), 0, sad) + if pattern == None: + pattern = msad + else: + pattern = ('iadd', pattern, msad) + pattern = (pattern[0] + '(many-comm-expr)', *pattern[1:]) + return pattern + +optimizations += [ + (vkd3d_proton_msad(), ('msad_4x8', a, b, 0), 'options->has_msad'), + (('iadd', ('msad_4x8', a, b, 0), c), ('msad_4x8', a, b, c)), +] + # "all_equal(eq(a, b), vec(~0))" is the same as "all_equal(a, b)" # "any_nequal(neq(a, b), vec(0))" is the same as "any_nequal(a, b)" diff --git a/src/compiler/nir/tests/algebraic_tests.cpp b/src/compiler/nir/tests/algebraic_tests.cpp index 65e15a57da4..a6fcc660efa 100644 --- a/src/compiler/nir/tests/algebraic_tests.cpp +++ b/src/compiler/nir/tests/algebraic_tests.cpp @@ -132,6 +132,46 @@ TEST_F(nir_opt_algebraic_test, irem_pow2_src2) test_2src_op(nir_op_irem, INT32_MIN, -4); } +TEST_F(nir_opt_algebraic_test, msad) +{ + 
  options.lower_bitfield_extract = true;
  options.has_bfe = true;
  options.has_msad = true;

  nir_def *src0 = nir_load_var(b, nir_local_variable_create(b->impl, glsl_int_type(), "src0"));
  nir_def *src1 = nir_load_var(b, nir_local_variable_create(b->impl, glsl_int_type(), "src1"));

  /* This mimics the sequence created by vkd3d-proton. */
  nir_def *res = NULL;
  for (unsigned i = 0; i < 4; i++) {
    nir_def *ref = nir_ubitfield_extract(b, src0, nir_imm_int(b, i * 8), nir_imm_int(b, 8));
    nir_def *src = nir_ubitfield_extract(b, src1, nir_imm_int(b, i * 8), nir_imm_int(b, 8));
    nir_def *is_ref_zero = nir_ieq_imm(b, ref, 0);
    nir_def *abs_diff = nir_iabs(b, nir_isub(b, ref, src));
    nir_def *masked_diff = nir_bcsel(b, is_ref_zero, nir_imm_int(b, 0), abs_diff);
    if (res)
      res = nir_iadd(b, res, masked_diff);
    else
      res = masked_diff;
  }

  nir_store_var(b, res_var, res, 0x1);

  while (nir_opt_algebraic(b->shader)) {
    nir_opt_constant_folding(b->shader);
    nir_opt_dce(b->shader);
  }

  unsigned count = 0;
  nir_foreach_instr(instr, nir_start_block(b->impl)) {
    if (instr->type == nir_instr_type_alu) {
      ASSERT_TRUE(nir_instr_as_alu(instr)->op == nir_op_msad_4x8);
      ASSERT_EQ(count, 0);
      count++;
    }
  }
}

TEST_F(nir_opt_idiv_const_test, umod)
{
  for (uint32_t d : {16u, 17u, 0u, UINT32_MAX}) { diff --git a/src/compiler/nir/tests/nir_test.h b/src/compiler/nir/tests/nir_test.h index c94fce23902..151071424e6 100644 --- a/src/compiler/nir/tests/nir_test.h +++ b/src/compiler/nir/tests/nir_test.h @@ -17,7 +17,6 @@ class nir_test : public ::testing::Test { { glsl_type_singleton_init_or_ref(); - static const nir_shader_compiler_options options = {}; _b = nir_builder_init_simple_shader(MESA_SHADER_COMPUTE, &options, "%s", name); b = &_b; } @@ -34,6 +33,7 @@ class nir_test : public ::testing::Test { glsl_type_singleton_decref(); } + nir_shader_compiler_options options = {}; nir_builder _b; nir_builder *b; };