Module: Mesa Branch: main Commit: 0477421f7db6007e9490fba449655af9f107918b URL: http://cgit.freedesktop.org/mesa/mesa/commit/?id=0477421f7db6007e9490fba449655af9f107918b
Author: Rhys Perry <pendingchao...@gmail.com> Date: Fri Nov 17 11:21:19 2023 +0000 nir: add msad_4x8 Signed-off-by: Rhys Perry <pendingchao...@gmail.com> Reviewed-by: Georg Lehmann <dadschoo...@gmail.com> Part-of: <https://gitlab.freedesktop.org/mesa/mesa/-/merge_requests/26907> --- src/compiler/nir/nir.h | 3 +++ src/compiler/nir/nir_constant_expressions.py | 12 ++++++++++++ src/compiler/nir/nir_opcodes.py | 23 ++++++++++++++++++----- src/compiler/nir/nir_range_analysis.c | 3 ++- 4 files changed, 35 insertions(+), 6 deletions(-) diff --git a/src/compiler/nir/nir.h b/src/compiler/nir/nir.h index 0aac0382480..61c2b4bc884 100644 --- a/src/compiler/nir/nir.h +++ b/src/compiler/nir/nir.h @@ -3932,6 +3932,9 @@ typedef struct nir_shader_compiler_options { /** Backend supports uclz. */ bool has_uclz; + /** Backend supports msad_4x8. */ + bool has_msad; + /** * Is this the Intel vec4 backend? * diff --git a/src/compiler/nir/nir_constant_expressions.py b/src/compiler/nir/nir_constant_expressions.py index ac302ba3165..50ea8756da6 100644 --- a/src/compiler/nir/nir_constant_expressions.py +++ b/src/compiler/nir/nir_constant_expressions.py @@ -375,6 +375,18 @@ static uint32_t pack_2x16_to_unorm_10_2(uint32_t src0) return vfmul_v3d(vfsat_v3d(src0), 0x000303ff); } +static uint32_t +msad(uint32_t src0, uint32_t src1, uint32_t src2) { + uint32_t res = src2; + for (unsigned i = 0; i < 4; i++) { + const uint8_t ref = src0 >> (i * 8); + const uint8_t src = src1 >> (i * 8); + if (ref != 0) + res += MAX2(ref, src) - MIN2(ref, src); + } + return res; +} + /* Some typed vector structures to make things like src0.y work */ typedef int8_t int1_t; typedef uint8_t uint1_t; diff --git a/src/compiler/nir/nir_opcodes.py b/src/compiler/nir/nir_opcodes.py index da53ea56f43..66ce98e46cf 100644 --- a/src/compiler/nir/nir_opcodes.py +++ b/src/compiler/nir/nir_opcodes.py @@ -1126,11 +1126,6 @@ if (bits == 0) { } """) -# Sum of absolute differences with accumulation. 
-# (Equivalent to AMD's v_sad_u8 instruction.) -# The first two sources contain packed 8-bit unsigned integers, the instruction -# will calculate the absolute difference of these, and then add them together. -# There is also a third source which is a 32-bit unsigned integer and added to the result. triop_horiz("sad_u8x4", 1, 1, 1, 1, """ uint8_t s0_b0 = (src0.x & 0x000000ff) >> 0; uint8_t s0_b1 = (src0.x & 0x0000ff00) >> 8; @@ -1147,6 +1142,24 @@ dst.x = src2.x + (s0_b1 > s1_b1 ? (s0_b1 - s1_b1) : (s1_b1 - s0_b1)) + (s0_b2 > s1_b2 ? (s0_b2 - s1_b2) : (s1_b2 - s0_b2)) + (s0_b3 > s1_b3 ? (s0_b3 - s1_b3) : (s1_b3 - s0_b3)); +""", description = """ +Sum of absolute differences with accumulation. Equivalent to AMD's v_sad_u8 instruction. + +The first two sources contain packed 8-bit unsigned integers, the instruction will +calculate the absolute difference of these, and then add them together. There is also a +third source which is a 32-bit unsigned integer and added to the result. +""") + +triop("msad_4x8", tuint32, "", """ +dst = msad(src0, src1, src2); +""", description = """ +Masked sum of absolute differences with accumulation. Equivalent to AMD's v_msad_u8 +instruction and DXIL's MSAD. + +The first two sources contain packed 8-bit unsigned integers, the instruction +will calculate the absolute difference of each pair of bytes whose src0 byte is +non-zero, and then add them together. There is also a third source which is a +32-bit unsigned integer and added to the result. """) # Combines the first component of each input to make a 3-component vector. 
diff --git a/src/compiler/nir/nir_range_analysis.c b/src/compiler/nir/nir_range_analysis.c index d982c9e602c..0ac5638f695 100644 --- a/src/compiler/nir/nir_range_analysis.c +++ b/src/compiler/nir/nir_range_analysis.c @@ -1865,7 +1865,8 @@ get_alu_uub(struct analysis_state *state, struct uub_query q, uint32_t *result, *result = 1; break; case nir_op_sad_u8x4: - *result = src[2] + 4 * 255; + case nir_op_msad_4x8: + *result = MIN2((uint64_t)src[2] + 4 * 255, UINT32_MAX); break; case nir_op_extract_u8: *result = MIN2(src[0], UINT8_MAX);