Module: Mesa
Branch: main
Commit: 0477421f7db6007e9490fba449655af9f107918b
URL:    http://cgit.freedesktop.org/mesa/mesa/commit/?id=0477421f7db6007e9490fba449655af9f107918b

Author: Rhys Perry <pendingchao...@gmail.com>
Date:   Fri Nov 17 11:21:19 2023 +0000

nir: add msad_4x8

Signed-off-by: Rhys Perry <pendingchao...@gmail.com>
Reviewed-by: Georg Lehmann <dadschoo...@gmail.com>
Part-of: <https://gitlab.freedesktop.org/mesa/mesa/-/merge_requests/26907>

---

 src/compiler/nir/nir.h                       |  3 +++
 src/compiler/nir/nir_constant_expressions.py | 12 ++++++++++++
 src/compiler/nir/nir_opcodes.py              | 23 ++++++++++++++++++-----
 src/compiler/nir/nir_range_analysis.c        |  3 ++-
 4 files changed, 35 insertions(+), 6 deletions(-)

diff --git a/src/compiler/nir/nir.h b/src/compiler/nir/nir.h
index 0aac0382480..61c2b4bc884 100644
--- a/src/compiler/nir/nir.h
+++ b/src/compiler/nir/nir.h
@@ -3932,6 +3932,9 @@ typedef struct nir_shader_compiler_options {
    /** Backend supports uclz. */
    bool has_uclz;
 
+   /** Backend supports msad_4x8. */
+   bool has_msad;
+
    /**
     * Is this the Intel vec4 backend?
     *
diff --git a/src/compiler/nir/nir_constant_expressions.py b/src/compiler/nir/nir_constant_expressions.py
index ac302ba3165..50ea8756da6 100644
--- a/src/compiler/nir/nir_constant_expressions.py
+++ b/src/compiler/nir/nir_constant_expressions.py
@@ -375,6 +375,18 @@ static uint32_t pack_2x16_to_unorm_10_2(uint32_t src0)
    return vfmul_v3d(vfsat_v3d(src0), 0x000303ff);
 }
 
+static uint32_t
+msad(uint32_t src0, uint32_t src1, uint32_t src2) {
+   uint32_t res = src2;
+   for (unsigned i = 0; i < 4; i++) {
+      const uint8_t ref = src0 >> (i * 8);
+      const uint8_t src = src1 >> (i * 8);
+      if (ref != 0)
+         res += MAX2(ref, src) - MIN2(ref, src);
+   }
+   return res;
+}
+
 /* Some typed vector structures to make things like src0.y work */
 typedef int8_t int1_t;
 typedef uint8_t uint1_t;
diff --git a/src/compiler/nir/nir_opcodes.py b/src/compiler/nir/nir_opcodes.py
index da53ea56f43..66ce98e46cf 100644
--- a/src/compiler/nir/nir_opcodes.py
+++ b/src/compiler/nir/nir_opcodes.py
@@ -1126,11 +1126,6 @@ if (bits == 0) {
 }
 """)
 
-# Sum of absolute differences with accumulation.
-# (Equivalent to AMD's v_sad_u8 instruction.)
-# The first two sources contain packed 8-bit unsigned integers, the instruction
-# will calculate the absolute difference of these, and then add them together.
-# There is also a third source which is a 32-bit unsigned integer and added to the result.
 triop_horiz("sad_u8x4", 1, 1, 1, 1, """
 uint8_t s0_b0 = (src0.x & 0x000000ff) >> 0;
 uint8_t s0_b1 = (src0.x & 0x0000ff00) >> 8;
@@ -1147,6 +1142,24 @@ dst.x = src2.x +
         (s0_b1 > s1_b1 ? (s0_b1 - s1_b1) : (s1_b1 - s0_b1)) +
         (s0_b2 > s1_b2 ? (s0_b2 - s1_b2) : (s1_b2 - s0_b2)) +
         (s0_b3 > s1_b3 ? (s0_b3 - s1_b3) : (s1_b3 - s0_b3));
+""", description = """
+Sum of absolute differences with accumulation. Equivalent to AMD's v_sad_u8 instruction.
+
+The first two sources contain packed 8-bit unsigned integers, the instruction will
+calculate the absolute difference of these, and then add them together. There is also a
+third source which is a 32-bit unsigned integer and added to the result.
+""")
+
+triop("msad_4x8", tuint32, "", """
+dst = msad(src0, src1, src2);
+""", description = """
+Masked sum of absolute differences with accumulation. Equivalent to AMD's v_msad_u8
+instruction and DXIL's MSAD.
+
+The first two sources contain packed 8-bit unsigned integers, the instruction
+will calculate the absolute difference of each pair where src0's byte is non-zero, and
+then add them together. There is also a third source which is a 32-bit unsigned
+integer and added to the result.
 """)
 
 # Combines the first component of each input to make a 3-component vector.
diff --git a/src/compiler/nir/nir_range_analysis.c b/src/compiler/nir/nir_range_analysis.c
index d982c9e602c..0ac5638f695 100644
--- a/src/compiler/nir/nir_range_analysis.c
+++ b/src/compiler/nir/nir_range_analysis.c
@@ -1865,7 +1865,8 @@ get_alu_uub(struct analysis_state *state, struct uub_query q, uint32_t *result,
       *result = 1;
       break;
    case nir_op_sad_u8x4:
-      *result = src[2] + 4 * 255;
+   case nir_op_msad_4x8:
+      *result = MIN2((uint64_t)src[2] + 4 * 255, UINT32_MAX);
       break;
    case nir_op_extract_u8:
       *result = MIN2(src[0], UINT8_MAX);

Reply via email to