Module: Mesa
Branch: main
Commit: e86ab8173b573e0754d0fa85f818dca0d85c4526
URL:    http://cgit.freedesktop.org/mesa/mesa/commit/?id=e86ab8173b573e0754d0fa85f818dca0d85c4526

Author: Rhys Perry <pendingchao...@gmail.com>
Date:   Fri Nov 17 12:24:14 2023 +0000

nir/algebraic: optimize vkd3d-proton's MSAD

Signed-off-by: Rhys Perry <pendingchao...@gmail.com>
Reviewed-by: Georg Lehmann <dadschoo...@gmail.com>
Part-of: <https://gitlab.freedesktop.org/mesa/mesa/-/merge_requests/26907>

---

 src/compiler/nir/nir_opt_algebraic.py      | 19 ++++++++++++++
 src/compiler/nir/tests/algebraic_tests.cpp | 40 ++++++++++++++++++++++++++++++
 src/compiler/nir/tests/nir_test.h          |  2 +-
 3 files changed, 60 insertions(+), 1 deletion(-)

diff --git a/src/compiler/nir/nir_opt_algebraic.py b/src/compiler/nir/nir_opt_algebraic.py
index 88293c995c0..55ae777a6f1 100644
--- a/src/compiler/nir/nir_opt_algebraic.py
+++ b/src/compiler/nir/nir_opt_algebraic.py
@@ -2622,6 +2622,25 @@ optimizations += [
    (vkd3d_proton_packed_f2f16_rtz_lo(('fneg', 'x'), ('fabs', 'x')), 
('pack_half_2x16_rtz_split', ('fneg', 'x'), 0)),
 ]
 
+def vkd3d_proton_msad():  # builds the search pattern for vkd3d-proton's 4-byte masked sum of absolute differences
+   pattern = None  # running sum of per-byte terms; stays None until the first term is built
+   for i in range(4):  # one term per byte index i = 0..3
+      ref = ('extract_u8', 'a@32', i)  # byte i of a (the "@32" suffix is nir_algebraic's bit-size annotation)
+      src = ('extract_u8', 'b@32', i)  # byte i of b
+      sad = ('iabs', ('iadd', ref, ('ineg', src)))  # |ref - src|, spelled as ref + (-src)
+      msad = ('bcsel', ('ieq', ref, 0), 0, sad)  # masked term: 0 when the reference byte equals 0, else sad
+      if pattern == None:
+         pattern = msad
+      else:
+         pattern = ('iadd', pattern, msad)  # left-nested chain: ((t0 + t1) + t2) + t3
+   pattern = (pattern[0] + '(many-comm-expr)', *pattern[1:])  # append the many-comm-expr flag to the outermost opcode only
+   return pattern
+
+optimizations += [  # rules that create msad_4x8 and then fold a following add into it
+   (vkd3d_proton_msad(), ('msad_4x8', a, b, 0), 'options->has_msad'),  # whole 4-byte sequence -> msad_4x8(a, b, 0), gated on options->has_msad
+   (('iadd', ('msad_4x8', a, b, 0), c), ('msad_4x8', a, b, c)),  # msad_4x8(a, b, 0) + c -> msad_4x8(a, b, c); this rule has no condition string
+]
+
 
 # "all_equal(eq(a, b), vec(~0))" is the same as "all_equal(a, b)"
 # "any_nequal(neq(a, b), vec(0))" is the same as "any_nequal(a, b)"
diff --git a/src/compiler/nir/tests/algebraic_tests.cpp b/src/compiler/nir/tests/algebraic_tests.cpp
index 65e15a57da4..a6fcc660efa 100644
--- a/src/compiler/nir/tests/algebraic_tests.cpp
+++ b/src/compiler/nir/tests/algebraic_tests.cpp
@@ -132,6 +132,46 @@ TEST_F(nir_opt_algebraic_test, irem_pow2_src2)
    test_2src_op(nir_op_irem, INT32_MIN, -4);
 }
 
+TEST_F(nir_opt_algebraic_test, msad) /* the vkd3d-proton MSAD sequence must leave msad_4x8 as the only ALU instruction */
+{
+   options.lower_bitfield_extract = true; /* NOTE(review): presumably lets the ubitfield_extract calls below reach the extract_u8 form the pattern matches -- confirm */
+   options.has_bfe = true;
+   options.has_msad = true; /* the msad_4x8 rewrite is conditioned on options->has_msad */
+
+   nir_def *src0 = nir_load_var(b, nir_local_variable_create(b->impl, 
glsl_int_type(), "src0"));
+   nir_def *src1 = nir_load_var(b, nir_local_variable_create(b->impl, 
glsl_int_type(), "src1"));
+
+   /* This mimics the sequence created by vkd3d-proton. */
+   nir_def *res = NULL; /* running sum of the per-byte masked differences */
+   for (unsigned i = 0; i < 4; i++) { /* one term per byte: 8-bit fields at bit offsets 0, 8, 16, 24 */
+      nir_def *ref = nir_ubitfield_extract(b, src0, nir_imm_int(b, i * 8), 
nir_imm_int(b, 8));
+      nir_def *src = nir_ubitfield_extract(b, src1, nir_imm_int(b, i * 8), 
nir_imm_int(b, 8));
+      nir_def *is_ref_zero = nir_ieq_imm(b, ref, 0);
+      nir_def *abs_diff = nir_iabs(b, nir_isub(b, ref, src)); /* |ref - src| */
+      nir_def *masked_diff = nir_bcsel(b, is_ref_zero, nir_imm_int(b, 0), 
abs_diff);
+      if (res)
+         res = nir_iadd(b, res, masked_diff);
+      else
+         res = masked_diff; /* first term starts the sum */
+   }
+
+   nir_store_var(b, res_var, res, 0x1); /* res_var is declared outside this hunk (presumably by the fixture) */
+
+   while (nir_opt_algebraic(b->shader)) { /* repeat until nir_opt_algebraic returns false */
+      nir_opt_constant_folding(b->shader);
+      nir_opt_dce(b->shader);
+   }
+
+   unsigned count = 0; /* ALU instructions seen so far in the start block */
+   nir_foreach_instr(instr, nir_start_block(b->impl)) {
+      if (instr->type == nir_instr_type_alu) {
+         ASSERT_TRUE(nir_instr_as_alu(instr)->op == nir_op_msad_4x8); /* any ALU instruction left must be msad_4x8 */
+         ASSERT_EQ(count, 0); /* ...and it must be the first one encountered */
+         count++;
+      }
+   } /* NOTE(review): count is not checked after the loop, so a block with zero ALU instructions would also pass */
+}
+
 TEST_F(nir_opt_idiv_const_test, umod)
 {
    for (uint32_t d : {16u, 17u, 0u, UINT32_MAX}) {
diff --git a/src/compiler/nir/tests/nir_test.h b/src/compiler/nir/tests/nir_test.h
index c94fce23902..151071424e6 100644
--- a/src/compiler/nir/tests/nir_test.h
+++ b/src/compiler/nir/tests/nir_test.h
@@ -17,7 +17,6 @@ class nir_test : public ::testing::Test {
    {
       glsl_type_singleton_init_or_ref();
 
-      static const nir_shader_compiler_options options = {};
       _b = nir_builder_init_simple_shader(MESA_SHADER_COMPUTE, &options, "%s", 
name);
       b = &_b;
    }
@@ -34,6 +33,7 @@ class nir_test : public ::testing::Test {
       glsl_type_singleton_decref();
    }
 
+   nir_shader_compiler_options options = {};
    nir_builder _b;
    nir_builder *b;
 };

Reply via email to