From 4d6f4ae4154ff2b67375ba36e6c7f6e45e2b909a Mon Sep 17 00:00:00 2001
Message-ID: <4d6f4ae4154ff2b67375ba36e6c7f6e45e2b909a.1724771133.git.hari.lim...@arm.com>
In-Reply-To: <cover.1724771133.git.hari.lim...@arm.com>
References: <cover.1724771133.git.hari.lim...@arm.com>
From: Jonathan Wright <jonathan.wri...@arm.com>
Date: Wed, 21 Aug 2024 18:02:21 +0100
Subject: [PATCH v2 9/9] AArch64: Add SVE implementation of 32x32 DCT

The widening 16-bit multiply + pairwise add pattern in the Neon DCT
paths is a good fit for the SVE 16-bit dot-product instructions. This
patch adds an SVE implementation of the 32x32 DCT path.

Relative performance compared to the Neon implementation:
  Neoverse-V1: 1.13x
  Neoverse-V2: 1.44x
  Neoverse-N2: 1.55x
---
 source/common/aarch64/dct-prim-sve.cpp  | 203 ++++++++++++++++++++++++
 source/common/aarch64/neon-sve-bridge.h |   5 +
 2 files changed, 208 insertions(+)

diff --git a/source/common/aarch64/dct-prim-sve.cpp b/source/common/aarch64/dct-prim-sve.cpp
index 3f6de3bff..75bc8d359 100644
--- a/source/common/aarch64/dct-prim-sve.cpp
+++ b/source/common/aarch64/dct-prim-sve.cpp
@@ -239,6 +239,191 @@ static inline void partialButterfly16_sve(const int16_t *src, int16_t *dst)
     }
 }
 
+template<int shift>
+static inline void partialButterfly32_sve(const int16_t *src, int16_t *dst)
+{
+    const int line = 32;
+
+    int16x8_t O[line][2];
+    int16x8_t EO[line];
+    int32x4_t EEO[line];
+    int32x4_t EEEE[line / 2];
+    int32x4_t EEEO[line / 2];
+
+    for (int i = 0; i < line; i += 2)
+    {
+        int16x8x4_t in_lo = vld1q_s16_x4(src + (i + 0) * line);
+        in_lo.val[2] = rev16(in_lo.val[2]);
+        in_lo.val[3] = rev16(in_lo.val[3]);
+
+        int16x8x4_t in_hi = vld1q_s16_x4(src + (i + 1) * line);
+        in_hi.val[2] = rev16(in_hi.val[2]);
+        in_hi.val[3] = rev16(in_hi.val[3]);
+
+        int32x4_t E0[4];
+        E0[0] = vaddl_s16(vget_low_s16(in_lo.val[0]),
+                          vget_low_s16(in_lo.val[3]));
+        E0[1] = vaddl_s16(vget_high_s16(in_lo.val[0]),
+                          vget_high_s16(in_lo.val[3]));
+        E0[2] = vaddl_s16(vget_low_s16(in_lo.val[1]),
+                          vget_low_s16(in_lo.val[2]));
+        E0[3] = vaddl_s16(vget_high_s16(in_lo.val[1]),
+                          vget_high_s16(in_lo.val[2]));
+
+        int32x4_t E1[4];
+        E1[0] = vaddl_s16(vget_low_s16(in_hi.val[0]),
+                          vget_low_s16(in_hi.val[3]));
+        E1[1] = vaddl_s16(vget_high_s16(in_hi.val[0]),
+                          vget_high_s16(in_hi.val[3]));
+        E1[2] = vaddl_s16(vget_low_s16(in_hi.val[1]),
+                          vget_low_s16(in_hi.val[2]));
+        E1[3] = vaddl_s16(vget_high_s16(in_hi.val[1]),
+                          vget_high_s16(in_hi.val[2]));
+
+        O[i + 0][0] = vsubq_s16(in_lo.val[0], in_lo.val[3]);
+        O[i + 0][1] = vsubq_s16(in_lo.val[1], in_lo.val[2]);
+
+        O[i + 1][0] = vsubq_s16(in_hi.val[0], in_hi.val[3]);
+        O[i + 1][1] = vsubq_s16(in_hi.val[1], in_hi.val[2]);
+
+        int32x4_t EE0[2];
+        E0[3] = rev32(E0[3]);
+        E0[2] = rev32(E0[2]);
+        EE0[0] = vaddq_s32(E0[0], E0[3]);
+        EE0[1] = vaddq_s32(E0[1], E0[2]);
+        EO[i + 0] = vcombine_s16(vmovn_s32(vsubq_s32(E0[0], E0[3])),
+                                 vmovn_s32(vsubq_s32(E0[1], E0[2])));
+
+        int32x4_t EE1[2];
+        E1[3] = rev32(E1[3]);
+        E1[2] = rev32(E1[2]);
+        EE1[0] = vaddq_s32(E1[0], E1[3]);
+        EE1[1] = vaddq_s32(E1[1], E1[2]);
+        EO[i + 1] = vcombine_s16(vmovn_s32(vsubq_s32(E1[0], E1[3])),
+                                 vmovn_s32(vsubq_s32(E1[1], E1[2])));
+
+        int32x4_t EEE0;
+        EE0[1] = rev32(EE0[1]);
+        EEE0 = vaddq_s32(EE0[0], EE0[1]);
+        EEO[i + 0] = vsubq_s32(EE0[0], EE0[1]);
+
+        int32x4_t EEE1;
+        EE1[1] = rev32(EE1[1]);
+        EEE1 = vaddq_s32(EE1[0], EE1[1]);
+        EEO[i + 1] = vsubq_s32(EE1[0], EE1[1]);
+
+        int32x4_t t0 = vreinterpretq_s32_s64(
+            vzip1q_s64(vreinterpretq_s64_s32(EEE0),
+                       vreinterpretq_s64_s32(EEE1)));
+        int32x4_t t1 = vrev64q_s32(vreinterpretq_s32_s64(
+            vzip2q_s64(vreinterpretq_s64_s32(EEE0),
+                       vreinterpretq_s64_s32(EEE1))));
+
+        EEEE[i / 2] = vaddq_s32(t0, t1);
+        EEEO[i / 2] = vsubq_s32(t0, t1);
+    }
+
+    for (int k = 1; k < 32; k += 2)
+    {
+        int16_t *d = dst + k * line;
+
+        int16x8_t c0_c1 = vld1q_s16(&g_t32[k][0]);
+        int16x8_t c2_c3 = vld1q_s16(&g_t32[k][8]);
+
+        for (int i = 0; i < line; i += 4)
+        {
+            int64x2_t t0 = x265_sdotq_s16(vdupq_n_s64(0), c0_c1, O[i + 0][0]);
+            int64x2_t t1 = x265_sdotq_s16(vdupq_n_s64(0), c0_c1, O[i + 1][0]);
+            int64x2_t t2 = x265_sdotq_s16(vdupq_n_s64(0), c0_c1, O[i + 2][0]);
+            int64x2_t t3 = x265_sdotq_s16(vdupq_n_s64(0), c0_c1, O[i + 3][0]);
+            t0 = x265_sdotq_s16(t0, c2_c3, O[i + 0][1]);
+            t1 = x265_sdotq_s16(t1, c2_c3, O[i + 1][1]);
+            t2 = x265_sdotq_s16(t2, c2_c3, O[i + 2][1]);
+            t3 = x265_sdotq_s16(t3, c2_c3, O[i + 3][1]);
+
+            int32x4_t t01 = vcombine_s32(vmovn_s64(t0), vmovn_s64(t1));
+            int32x4_t t23 = vcombine_s32(vmovn_s64(t2), vmovn_s64(t3));
+            int16x4_t res = vrshrn_n_s32(vpaddq_s32(t01, t23), shift);
+            vst1_s16(d, res);
+
+            d += 4;
+        }
+    }
+
+    for (int k = 2; k < 32; k += 4)
+    {
+        int16_t *d = dst + k * line;
+
+        int16x8_t c0 = vld1q_s16(&g_t32[k][0]);
+
+        for (int i = 0; i < line; i += 4)
+        {
+            int64x2_t t0 = x265_sdotq_s16(vdupq_n_s64(0), c0, EO[i + 0]);
+            int64x2_t t1 = x265_sdotq_s16(vdupq_n_s64(0), c0, EO[i + 1]);
+            int64x2_t t2 = x265_sdotq_s16(vdupq_n_s64(0), c0, EO[i + 2]);
+            int64x2_t t3 = x265_sdotq_s16(vdupq_n_s64(0), c0, EO[i + 3]);
+
+            int32x4_t t01 = vcombine_s32(vmovn_s64(t0), vmovn_s64(t1));
+            int32x4_t t23 = vcombine_s32(vmovn_s64(t2), vmovn_s64(t3));
+            int16x4_t res = vrshrn_n_s32(vpaddq_s32(t01, t23), shift);
+            vst1_s16(d, res);
+
+            d += 4;
+        }
+    }
+
+    for (int k = 4; k < 32; k += 8)
+    {
+        int16_t *d = dst + k * line;
+
+        int32x4_t c = x265_vld1sh_s32(&g_t32[k][0]);
+
+        for (int i = 0; i < line; i += 4)
+        {
+            int32x4_t t0 = vmulq_s32(c, EEO[i + 0]);
+            int32x4_t t1 = vmulq_s32(c, EEO[i + 1]);
+            int32x4_t t2 = vmulq_s32(c, EEO[i + 2]);
+            int32x4_t t3 = vmulq_s32(c, EEO[i + 3]);
+
+            int32x4_t t = vpaddq_s32(vpaddq_s32(t0, t1), vpaddq_s32(t2, t3));
+            int16x4_t res = vrshrn_n_s32(t, shift);
+            vst1_s16(d, res);
+
+            d += 4;
+        }
+    }
+
+    int32x4_t c0 = vld1q_s32(t8_even[0]);
+    int32x4_t c8 = vld1q_s32(t8_even[1]);
+    int32x4_t c16 = vld1q_s32(t8_even[2]);
+    int32x4_t c24 = vld1q_s32(t8_even[3]);
+
+    for (int i = 0; i < line; i += 4)
+    {
+        int32x4_t t0 = vpaddq_s32(EEEE[i / 2 + 0], EEEE[i / 2 + 1]);
+        int32x4_t t1 = vmulq_s32(c0, t0);
+        int16x4_t res0 = vrshrn_n_s32(t1, shift);
+        vst1_s16(dst + 0 * line, res0);
+
+        int32x4_t t2 = vmulq_s32(c8, EEEO[i / 2 + 0]);
+        int32x4_t t3 = vmulq_s32(c8, EEEO[i / 2 + 1]);
+        int16x4_t res8 = vrshrn_n_s32(vpaddq_s32(t2, t3), shift);
+        vst1_s16(dst + 8 * line, res8);
+
+        int32x4_t t4 = vmulq_s32(c16, EEEE[i / 2 + 0]);
+        int32x4_t t5 = vmulq_s32(c16, EEEE[i / 2 + 1]);
+        int16x4_t res16 = vrshrn_n_s32(vpaddq_s32(t4, t5), shift);
+        vst1_s16(dst + 16 * line, res16);
+
+        int32x4_t t6 = vmulq_s32(c24, EEEO[i / 2 + 0]);
+        int32x4_t t7 = vmulq_s32(c24, EEEO[i / 2 + 1]);
+        int16x4_t res24 = vrshrn_n_s32(vpaddq_s32(t6, t7), shift);
+        vst1_s16(dst + 24 * line, res24);
+
+        dst += 4;
+    }
+}
+
 }
 
 
@@ -279,10 +464,28 @@ void dct16_sve(const int16_t *src, int16_t *dst, intptr_t srcStride)
     partialButterfly16_sve<shift_pass2>(coef, dst);
 }
 
+void dct32_sve(const int16_t *src, int16_t *dst, intptr_t srcStride)
+{
+    const int shift_pass1 = 4 + X265_DEPTH - 8;
+    const int shift_pass2 = 11;
+
+    ALIGN_VAR_32(int16_t, coef[32 * 32]);
+    ALIGN_VAR_32(int16_t, block[32 * 32]);
+
+    for (int i = 0; i < 32; i++)
+    {
+        memcpy(&block[i * 32], &src[i * srcStride], 32 * sizeof(int16_t));
+    }
+
+    partialButterfly32_sve<shift_pass1>(block, coef);
+    partialButterfly32_sve<shift_pass2>(coef, dst);
+}
+
 void setupDCTPrimitives_sve(EncoderPrimitives &p)
 {
     p.cu[BLOCK_8x8].dct = dct8_sve;
     p.cu[BLOCK_16x16].dct = dct16_sve;
+    p.cu[BLOCK_32x32].dct = dct32_sve;
 }
 
 };
diff --git a/source/common/aarch64/neon-sve-bridge.h b/source/common/aarch64/neon-sve-bridge.h
index 59ca8aab7..dad5fa909 100644
--- a/source/common/aarch64/neon-sve-bridge.h
+++ b/source/common/aarch64/neon-sve-bridge.h
@@ -40,6 +40,11 @@
  * remainder of the vector is unused - this approach is still beneficial when
  * compared to a Neon-only implementation.
  */
+static inline int32x4_t x265_vld1sh_s32(const int16_t *ptr)
+{
+    return svget_neonq_s32(svld1sh_s32(svptrue_pat_b32(SV_VL4), ptr));
+}
+
 static inline int64x2_t x265_sdotq_s16(int64x2_t acc, int16x8_t x, int16x8_t y)
 {
     return svget_neonq_s64(svdot_s64(svset_neonq_s64(svundef_s64(), acc),
-- 
2.42.1
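As a short aside for readers unfamiliar with the pattern (not part of the patch): the commit message's claim is that the Neon "widening 16-bit multiply + pairwise add" reduction collapses into a single SVE SDOT per 64-bit lane. The sketch below illustrates that mapping. dot8_neon() and dot8_sve() are illustrative names, not functions in the tree; dot8_sve() is written in the style of the existing x265_sdotq_s16() helper in neon-sve-bridge.h, whose body is cut off by the hunk context above, so the two svset_neonq_s16() arguments are an assumption about how that helper completes. Building the SVE variant assumes an SVE-enabled target, e.g. -march=armv8.2-a+sve.

#include <arm_neon.h>
#include <arm_neon_sve_bridge.h>

// Neon-only reference: widen each 16-bit product to 32 bits, then
// reduce across the vector. Returns sum(x[i] * y[i]) over all 8 lanes.
static inline int64_t dot8_neon(int16x8_t x, int16x8_t y)
{
    int32x4_t lo = vmull_s16(vget_low_s16(x), vget_low_s16(y));
    int32x4_t hi = vmull_s16(vget_high_s16(x), vget_high_s16(y));
    return vaddlvq_s32(vaddq_s32(lo, hi));
}

// SVE bridge version (illustrative, modelled on x265_sdotq_s16()): each
// 64-bit lane of the accumulator gains the sum of four 16x16 products,
// so one SDOT replaces the widening multiply + pairwise-add chain.
static inline int64x2_t dot8_sve(int64x2_t acc, int16x8_t x, int16x8_t y)
{
    return svget_neonq_s64(svdot_s64(svset_neonq_s64(svundef_s64(), acc),
                                     svset_neonq_s16(svundef_s16(), x),
                                     svset_neonq_s16(svundef_s16(), y)));
}

A scalar sum could be recovered from the SVE result with vaddvq_s64(), but the kernels in partialButterfly32_sve() keep the two 64-bit partial sums vectorized and only narrow and pair them later with vmovn_s64() and vpaddq_s32().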
_______________________________________________
x265-devel mailing list
x265-devel@videolan.org
https://mailman.videolan.org/listinfo/x265-devel