>From 8a2704d994bcd1beedf13a3d51e67f9726a58048 Mon Sep 17 00:00:00 2001
Message-ID: <8a2704d994bcd1beedf13a3d51e67f9726a58048.1724771133.git.hari.lim...@arm.com>
In-Reply-To: <cover.1724771133.git.hari.lim...@arm.com>
References: <cover.1724771133.git.hari.lim...@arm.com>
From: Hari Limaye <hari.lim...@arm.com>
Date: Wed, 6 Mar 2024 18:13:48 +0000
Subject: [PATCH v2 4/9] AArch64: Optimise partialButterfly16_neon

Optimise the Neon implementation of partialButterfly16 to process four
lines at a time, making use of the full width of Neon vector registers,
avoiding widening to 32-bit values where possible, and replacing the
addition of a rounding constant with rounding shift instructions.

Relative performance observed compared to the existing implementation:

  Neoverse N1: 1.31x
  Neoverse V1: 1.69x
  Neoverse N2: 1.29x
  Neoverse V2: 1.78x

Co-authored-by: Jonathan Wright <jonathan.wright at arm.com>
---
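Not part of the patch: a minimal standalone sketch of the rounding rewrite described
in the commit message, using arbitrary test values and hypothetical names. It checks
that a single rounding narrowing shift (vrshrn_n_s32) matches the previous
add-rounding-constant, shift, narrow sequence; the two agree whenever the 32-bit
addition cannot overflow.

// Illustrative only: compare the old rounding pattern with vrshrn_n_s32.
// Build for AArch64, e.g. g++ -O2 round_shift_demo.cpp (file name hypothetical).
#include <arm_neon.h>
#include <cstdio>

int main()
{
    constexpr int shift = 10;   // pass-2 shift; the _n_ intrinsics need a constant
    const int32_t in[4] = { 511, -511, 123456, -123456 };
    int32x4_t x = vld1q_s32(in);

    // Old pattern: add 1 << (shift - 1), arithmetic shift right, narrow to 16 bits.
    int32x4_t sum = vaddq_s32(x, vdupq_n_s32(1 << (shift - 1)));
    int16x4_t old_way = vmovn_s32(vshrq_n_s32(sum, shift));

    // New pattern: one rounding narrowing shift instruction.
    int16x4_t new_way = vrshrn_n_s32(x, shift);

    int16_t a[4], b[4];
    vst1_s16(a, old_way);
    vst1_s16(b, new_way);
    for (int i = 0; i < 4; i++)
    {
        std::printf("%d %d\n", a[i], b[i]);   // both columns should be identical
    }
    return 0;
}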
 source/common/aarch64/dct-prim.cpp | 142 +++++++++++++++++++----------
 1 file changed, 96 insertions(+), 46 deletions(-)

diff --git a/source/common/aarch64/dct-prim.cpp b/source/common/aarch64/dct-prim.cpp
index e07872157..c880bc72c 100644
--- a/source/common/aarch64/dct-prim.cpp
+++ b/source/common/aarch64/dct-prim.cpp
@@ -250,63 +250,113 @@ uint32_t copy_count_neon(int16_t *coeff, const int16_t *residual, intptr_t resiS
     return numSig - vaddvq_s16(vcount);
 }
 
-
-static void partialButterfly16(const int16_t *src, int16_t *dst, int shift, int line)
+template<int shift>
+static inline void partialButterfly16_neon(const int16_t *src, int16_t *dst)
 {
-    int j, k;
-    int32x4_t E[2], O[2];
-    int32x4_t EE, EO;
-    int32x2_t EEE, EEO;
-    const int add = 1 << (shift - 1);
-    const int32x4_t _vadd = {add, 0};
+    const int line = 16;
 
-    for (j = 0; j < line; j++)
+    int16x8_t O[line];
+    int32x4_t EO[line];
+    int32x4_t EEE[line];
+    int32x4_t EEO[line];
+
+    for (int i = 0; i < line; i += 2)
     {
-        int16x8_t in0 = vld1q_s16(src);
-        int16x8_t in1 = rev16(vld1q_s16(src + 8));
+        int16x8_t s0_lo = vld1q_s16(src + i * line);
+        int16x8_t s0_hi = rev16(vld1q_s16(src + i * line + 8));
 
-        E[0] = vaddl_s16(vget_low_s16(in0), vget_low_s16(in1));
-        O[0] = vsubl_s16(vget_low_s16(in0), vget_low_s16(in1));
-        E[1] = vaddl_high_s16(in0, in1);
-        O[1] = vsubl_high_s16(in0, in1);
+        int16x8_t s1_lo = vld1q_s16(src + (i + 1) * line);
+        int16x8_t s1_hi = rev16(vld1q_s16(src + (i + 1) * line + 8));
 
-        for (k = 1; k < 16; k += 2)
-        {
-            int32x4_t c0 = vmovl_s16(vld1_s16(&g_t16[k][0]));
-            int32x4_t c1 = vmovl_s16(vld1_s16(&g_t16[k][4]));
+        int32x4_t E0[2];
+        E0[0] = vaddl_s16(vget_low_s16(s0_lo), vget_low_s16(s0_hi));
+        E0[1] = vaddl_s16(vget_high_s16(s0_lo), vget_high_s16(s0_hi));
 
-            int32x4_t res = _vadd;
-            res = vmlaq_s32(res, c0, O[0]);
-            res = vmlaq_s32(res, c1, O[1]);
-            dst[k * line] = (int16_t)(vaddvq_s32(res) >> shift);
-        }
+        int32x4_t E1[2];
+        E1[0] = vaddl_s16(vget_low_s16(s1_lo), vget_low_s16(s1_hi));
+        E1[1] = vaddl_s16(vget_high_s16(s1_lo), vget_high_s16(s1_hi));
 
-        /* EE and EO */
-        EE = vaddq_s32(E[0], rev32(E[1]));
-        EO = vsubq_s32(E[0], rev32(E[1]));
+        O[i + 0] = vsubq_s16(s0_lo, s0_hi);
+        O[i + 1] = vsubq_s16(s1_lo, s1_hi);
+
+        int32x4_t EE0 = vaddq_s32(E0[0], rev32(E0[1]));
+        int32x4_t EE1 = vaddq_s32(E1[0], rev32(E1[1]));
+        EO[i + 0] = vsubq_s32(E0[0], rev32(E0[1]));
+        EO[i + 1] = vsubq_s32(E1[0], rev32(E1[1]));
+
+        int32x4_t t0 = vreinterpretq_s32_s64(
+            vzip1q_s64(vreinterpretq_s64_s32(EE0), vreinterpretq_s64_s32(EE1)));
+        int32x4_t t1 = vrev64q_s32(vreinterpretq_s32_s64(vzip2q_s64(
+            vreinterpretq_s64_s32(EE0), vreinterpretq_s64_s32(EE1))));
 
-        for (k = 2; k < 16; k += 4)
+
+        EEE[i / 2] = vaddq_s32(t0, t1);
+        EEO[i / 2] = vsubq_s32(t0, t1);
+    }
+
+    for (int i = 0; i < line; i += 4)
+    {
+        for (int k = 1; k < 16; k += 2)
+        {
+            int16x8_t c0_c4 = vld1q_s16(&g_t16[k][0]);
+
+            int32x4_t t0 = vmull_s16(vget_low_s16(c0_c4),
+                                     vget_low_s16(O[i + 0]));
+            int32x4_t t1 = vmull_s16(vget_low_s16(c0_c4),
+                                     vget_low_s16(O[i + 1]));
+            int32x4_t t2 = vmull_s16(vget_low_s16(c0_c4),
+                                     vget_low_s16(O[i + 2]));
+            int32x4_t t3 = vmull_s16(vget_low_s16(c0_c4),
+                                     vget_low_s16(O[i + 3]));
+            t0 = vmlal_s16(t0, vget_high_s16(c0_c4), vget_high_s16(O[i + 0]));
+            t1 = vmlal_s16(t1, vget_high_s16(c0_c4), vget_high_s16(O[i + 1]));
+            t2 = vmlal_s16(t2, vget_high_s16(c0_c4), vget_high_s16(O[i + 2]));
+            t3 = vmlal_s16(t3, vget_high_s16(c0_c4), vget_high_s16(O[i + 3]));
+
+            int32x4_t t = vpaddq_s32(vpaddq_s32(t0, t1), vpaddq_s32(t2, t3));
+            int16x4_t res = vrshrn_n_s32(t, shift);
+            vst1_s16(dst + k * line, res);
+        }
+
+        for (int k = 2; k < 16; k += 4)
         {
             int32x4_t c0 = vmovl_s16(vld1_s16(&g_t16[k][0]));
 
-            int32x4_t res = _vadd;
-            res = vmlaq_s32(res, c0, EO);
-            dst[k * line] = (int16_t)(vaddvq_s32(res) >> shift);
+            int32x4_t t0 = vmulq_s32(c0, EO[i + 0]);
+            int32x4_t t1 = vmulq_s32(c0, EO[i + 1]);
+            int32x4_t t2 = vmulq_s32(c0, EO[i + 2]);
+            int32x4_t t3 = vmulq_s32(c0, EO[i + 3]);
+            int32x4_t t = vpaddq_s32(vpaddq_s32(t0, t1), vpaddq_s32(t2, t3));
+
+            int16x4_t res = vrshrn_n_s32(t, shift);
+            vst1_s16(dst + k * line, res);
        }
 
-        /* EEE and EEO */
-        EEE[0] = EE[0] + EE[3];
-        EEO[0] = EE[0] - EE[3];
-        EEE[1] = EE[1] + EE[2];
-        EEO[1] = EE[1] - EE[2];
+        int32x4_t c0 = vld1q_s32(t8_even[0]);
+        int32x4_t c4 = vld1q_s32(t8_even[1]);
+        int32x4_t c8 = vld1q_s32(t8_even[2]);
+        int32x4_t c12 = vld1q_s32(t8_even[3]);
 
-        dst[0] = (int16_t)((g_t16[0][0] * EEE[0] + g_t16[0][1] * EEE[1] + add) >> shift);
-        dst[8 * line] = (int16_t)((g_t16[8][0] * EEE[0] + g_t16[8][1] * EEE[1] + add) >> shift);
-        dst[4 * line] = (int16_t)((g_t16[4][0] * EEO[0] + g_t16[4][1] * EEO[1] + add) >> shift);
-        dst[12 * line] = (int16_t)((g_t16[12][0] * EEO[0] + g_t16[12][1] * EEO[1] + add) >> shift);
+        int32x4_t t0 = vpaddq_s32(EEE[i / 2 + 0], EEE[i / 2 + 1]);
+        int32x4_t t1 = vmulq_s32(c0, t0);
+        int16x4_t res0 = vrshrn_n_s32(t1, shift);
+        vst1_s16(dst + 0 * line, res0);
+        int32x4_t t2 = vmulq_s32(c4, EEO[i / 2 + 0]);
+        int32x4_t t3 = vmulq_s32(c4, EEO[i / 2 + 1]);
+        int16x4_t res4 = vrshrn_n_s32(vpaddq_s32(t2, t3), shift);
+        vst1_s16(dst + 4 * line, res4);
 
-        src += 16;
-        dst++;
+        int32x4_t t4 = vmulq_s32(c8, EEE[i / 2 + 0]);
+        int32x4_t t5 = vmulq_s32(c8, EEE[i / 2 + 1]);
+        int16x4_t res8 = vrshrn_n_s32(vpaddq_s32(t4, t5), shift);
+        vst1_s16(dst + 8 * line, res8);
+
+        int32x4_t t6 = vmulq_s32(c12, EEO[i / 2 + 0]);
+        int32x4_t t7 = vmulq_s32(c12, EEO[i / 2 + 1]);
+        int16x4_t res12 = vrshrn_n_s32(vpaddq_s32(t6, t7), shift);
+        vst1_s16(dst + 12 * line, res12);
+
+        dst += 4;
     }
 }
 
@@ -898,8 +948,8 @@ void dct8_neon(const int16_t *src, int16_t *dst, intptr_t srcStride)
 
 void dct16_neon(const int16_t *src, int16_t *dst, intptr_t srcStride)
 {
-    const int shift_1st = 3 + X265_DEPTH - 8;
-    const int shift_2nd = 10;
+    const int shift_pass1 = 3 + X265_DEPTH - 8;
+    const int shift_pass2 = 10;
 
     ALIGN_VAR_32(int16_t, coef[16 * 16]);
     ALIGN_VAR_32(int16_t, block[16 * 16]);
@@ -909,8 +959,8 @@ void dct16_neon(const int16_t *src, int16_t *dst, intptr_t srcStride)
         memcpy(&block[i * 16], &src[i * srcStride], 16 * sizeof(int16_t));
     }
 
-    partialButterfly16(block, coef, shift_1st, 16);
-    partialButterfly16(coef, dst, shift_2nd, 16);
+    partialButterfly16_neon<shift_pass1>(block, coef);
+    partialButterfly16_neon<shift_pass2>(coef, dst);
 }
 
 void dct32_neon(const int16_t *src, int16_t *dst, intptr_t srcStride)
-- 
2.42.1
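A note on the template parameter, following the patch (the names below are
hypothetical, not taken from dct-prim.cpp): the _n_ forms of the Neon intrinsics,
including vrshrn_n_s32, require the shift count to be a compile-time constant, so
the former runtime shift argument becomes a non-type template parameter and each
DCT pass instantiates the butterfly with its own immediate. A minimal sketch of
that pattern, which compiles as a standalone AArch64 translation unit:

// Sketch only: round_narrow, pass1 and pass2 are illustrative names.
#include <arm_neon.h>

template<int shift>
static inline int16x4_t round_narrow(int32x4_t v)
{
    // 'shift' is a constant expression inside each instantiation, so the
    // immediate form of the rounding narrowing shift is legal here.
    return vrshrn_n_s32(v, shift);
}

// Two instantiations mirroring the two dct16 passes (3 assumes 8-bit depth).
int16x4_t pass1(int32x4_t v) { return round_narrow<3>(v); }
int16x4_t pass2(int32x4_t v) { return round_narrow<10>(v); }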