On Wed, Mar 20, 2024 at 09:31:16AM -0500, Nathan Bossart wrote:
> On Wed, Mar 20, 2024 at 01:57:54PM +0700, John Naylor wrote:
>> On Tue, Mar 19, 2024 at 11:30 PM Nathan Bossart
>> <nathandboss...@gmail.com> wrote:
>>> I tried to trim some of the branches, and came up with the attached patch.
>>> I don't think this is exactly what you were suggesting, but I think it's
>>> relatively close.  My testing showed decent benefits from using 2 vectors
>>> when there aren't enough elements for 4, so I've tried to keep that part
>>> intact.
>> 
>> I would caution against that if the benchmark is repeatedly running
>> against a static number of elements, because the branch predictor will
>> be right all the time (except maybe when it exits a loop, not sure).
>> We probably don't need to go to the trouble to construct a benchmark
>> with some added randomness, but we have to be careful not to overfit what
>> the test is actually measuring.
> 
> I don't mind removing the 2-register stuff if that's what you think we
> should do.  I'm cautiously optimistic that it'd help more than the extra
> branch prediction might hurt, and it'd at least help avoid regressing the
> lower end for the larger AVX2 registers, but I probably won't be able to
> prove that without constructing another benchmark.  And TBH I'm not sure
> it'll significantly impact any real-world workload, anyway.
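
To make that concrete: the sort of benchmark John is describing would
randomize the element count per call so the predictor can't memorize a
single branch history.  A rough standalone sketch (lfind32_stub and all
of the sizes here are placeholders, not code from the patch; a real run
would link in pg_lfind32() instead):

#include <stdbool.h>
#include <stdint.h>
#include <stdio.h>
#include <stdlib.h>
#include <time.h>

/* scalar stand-in for pg_lfind32(); swap in the real thing to measure it */
static bool
lfind32_stub(uint32_t key, const uint32_t *base, uint32_t nelem)
{
    for (uint32_t i = 0; i < nelem; i++)
        if (base[i] == key)
            return true;
    return false;
}

int
main(void)
{
    enum { MAXN = 128, NLENS = 1024, CALLS = 10 * 1000 * 1000 };
    uint32_t arr[MAXN];
    uint32_t lens[NLENS];
    unsigned hits = 0;
    clock_t start;

    for (int i = 0; i < MAXN; i++)
        arr[i] = (uint32_t) i;

    /* pre-generate the lengths so rand() stays out of the timed loop */
    srand((unsigned) time(NULL));
    for (int i = 0; i < NLENS; i++)
        lens[i] = 1 + (uint32_t) (rand() % MAXN);

    /* key is never present, so every call scans its full length */
    start = clock();
    for (int i = 0; i < CALLS; i++)
        hits += lfind32_stub(UINT32_MAX, arr, lens[i % NLENS]);
    printf("%u hits, %.3f s\n", hits,
           (double) (clock() - start) / CLOCKS_PER_SEC);
    return 0;
}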

Here's a new version of the patch set with the 2-register stuff removed,
plus a fresh run of the benchmark.  The weird spike for AVX2 is what led me
down the 2-register path earlier.
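
For reference, the tail handling in 0001 is the usual overlapped-block
trick: run the main loop over full 4-register blocks, then re-scan one
final block that ends exactly at the last element, accepting that up to
a block's worth of elements gets checked twice.  In scalar C the shape
is roughly this (block_contains() is a made-up stand-in for the
vectorized helper; BLOCK must be a power of two, e.g. 16 for SSE2, i.e.
4 registers x 4 uint32s):

#include <stdbool.h>
#include <stdint.h>

#define BLOCK 16                    /* elements per 4-register block */

/* made-up scalar stand-in for the vectorized pg_lfind32_helper() */
static bool
block_contains(uint32_t key, const uint32_t *base)
{
    for (int i = 0; i < BLOCK; i++)
        if (base[i] == key)
            return true;
    return false;
}

static bool
find_with_overlap(uint32_t key, const uint32_t *base, uint32_t nelem)
{
    if (nelem >= BLOCK)
    {
        uint32_t tail_idx = nelem & ~(BLOCK - 1);   /* last full-block boundary */
        uint32_t i;

        for (i = 0; i < tail_idx; i += BLOCK)
            if (block_contains(key, &base[i]))
                return true;

        /* overlapped tail: re-checks up to BLOCK - 1 elements, harmlessly */
        if (i != nelem && block_contains(key, &base[nelem - BLOCK]))
            return true;
        return false;
    }

    /* too few elements for even one block: plain linear search */
    for (uint32_t i = 0; i < nelem; i++)
        if (base[i] == key)
            return true;
    return false;
}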

-- 
Nathan Bossart
Amazon Web Services: https://aws.amazon.com
From d47b3219fd1b803a5dedff9babaa5134c07e6947 Mon Sep 17 00:00:00 2001
From: Nathan Bossart <nat...@postgresql.org>
Date: Wed, 20 Mar 2024 14:20:24 -0500
Subject: [PATCH v5 1/2] pg_lfind32(): add "overlap" code for remaining
 elements

---
 src/include/port/pg_lfind.h | 102 +++++++++++++++++++++++++-----------
 1 file changed, 71 insertions(+), 31 deletions(-)

diff --git a/src/include/port/pg_lfind.h b/src/include/port/pg_lfind.h
index b8dfa66eef..21af399dc4 100644
--- a/src/include/port/pg_lfind.h
+++ b/src/include/port/pg_lfind.h
@@ -80,6 +80,49 @@ pg_lfind8_le(uint8 key, uint8 *base, uint32 nelem)
 	return false;
 }
 
+/*
+ * pg_lfind32_helper
+ *
+ * Searches one 4-register block of integers.  The caller is responsible for
+ * ensuring that there are at least 4 registers' worth of integers remaining.
+ */
+static inline bool
+pg_lfind32_helper(const Vector32 keys, uint32 *base)
+{
+	const uint32 nelem_per_vector = sizeof(Vector32) / sizeof(uint32);
+	Vector32	vals1,
+				vals2,
+				vals3,
+				vals4,
+				result1,
+				result2,
+				result3,
+				result4,
+				tmp1,
+				tmp2,
+				result;
+
+	/* load the next block into 4 registers */
+	vector32_load(&vals1, base);
+	vector32_load(&vals2, &base[nelem_per_vector]);
+	vector32_load(&vals3, &base[nelem_per_vector * 2]);
+	vector32_load(&vals4, &base[nelem_per_vector * 3]);
+
+	/* compare each value to the key */
+	result1 = vector32_eq(keys, vals1);
+	result2 = vector32_eq(keys, vals2);
+	result3 = vector32_eq(keys, vals3);
+	result4 = vector32_eq(keys, vals4);
+
+	/* combine the results into a single variable */
+	tmp1 = vector32_or(result1, result2);
+	tmp2 = vector32_or(result3, result4);
+	result = vector32_or(tmp1, tmp2);
+
+	/* return whether there was a match */
+	return vector32_is_highbit_set(result);
+}
+
 /*
  * pg_lfind32
  *
@@ -119,46 +162,43 @@ pg_lfind32(uint32 key, uint32 *base, uint32 nelem)
 	}
 #endif
 
+	/*
+	 * If there aren't enough elements for the SIMD code, jump to the standard
+	 * one-by-one linear search code.
+	 */
+	if (nelem < nelem_per_iteration)
+		goto one_by_one;
+
+	/*
+	 * Process as many elements as possible with a block of 4 registers.
+	 */
 	for (i = 0; i < tail_idx; i += nelem_per_iteration)
 	{
-		Vector32	vals1,
-					vals2,
-					vals3,
-					vals4,
-					result1,
-					result2,
-					result3,
-					result4,
-					tmp1,
-					tmp2,
-					result;
-
-		/* load the next block into 4 registers */
-		vector32_load(&vals1, &base[i]);
-		vector32_load(&vals2, &base[i + nelem_per_vector]);
-		vector32_load(&vals3, &base[i + nelem_per_vector * 2]);
-		vector32_load(&vals4, &base[i + nelem_per_vector * 3]);
-
-		/* compare each value to the key */
-		result1 = vector32_eq(keys, vals1);
-		result2 = vector32_eq(keys, vals2);
-		result3 = vector32_eq(keys, vals3);
-		result4 = vector32_eq(keys, vals4);
-
-		/* combine the results into a single variable */
-		tmp1 = vector32_or(result1, result2);
-		tmp2 = vector32_or(result3, result4);
-		result = vector32_or(tmp1, tmp2);
-
-		/* see if there was a match */
-		if (vector32_is_highbit_set(result))
+		if (pg_lfind32_helper(keys, &base[i]))
 		{
 			Assert(assert_result == true);
 			return true;
 		}
 	}
+
+	/*
+	 * If any elements remain, process the last 'nelem_per_iteration' elements
+	 * in the array with a 4-register block.  This will cause us to check some
+	 * elements more than once, but that won't affect correctness, and testing
+	 * has demonstrated that this helps more cases than it harms.
+	 */
+	if (i != nelem &&
+		pg_lfind32_helper(keys, &base[nelem - nelem_per_iteration]))
+	{
+		Assert(assert_result);
+		return true;
+	}
+
+	Assert(!assert_result);
+	return false;
 #endif							/* ! USE_NO_SIMD */
 
+one_by_one:
 	/* Process the remaining elements one at a time. */
 	for (; i < nelem; i++)
 	{
-- 
2.25.1

From e8337b123d828671d5c547d2a96485ef15f4ddfe Mon Sep 17 00:00:00 2001
From: Nathan Bossart <nat...@postgresql.org>
Date: Mon, 18 Mar 2024 11:02:05 -0500
Subject: [PATCH v5 2/2] Add support for AVX2 in simd.h.

Discussion: https://postgr.es/m/20231129171526.GA857928%40nathanxps13
---
 src/include/port/simd.h | 61 ++++++++++++++++++++++++++++++++---------
 1 file changed, 48 insertions(+), 13 deletions(-)

diff --git a/src/include/port/simd.h b/src/include/port/simd.h
index 597496f2fb..f06b21876b 100644
--- a/src/include/port/simd.h
+++ b/src/include/port/simd.h
@@ -18,7 +18,18 @@
 #ifndef SIMD_H
 #define SIMD_H
 
-#if (defined(__x86_64__) || defined(_M_AMD64))
+#if defined(__AVX2__)
+
+/*
+ * XXX: Need to add a big comment here.
+ */
+#include <immintrin.h>
+#define USE_AVX2
+typedef __m256i Vector8;
+typedef __m256i Vector32;
+
+#elif (defined(__x86_64__) || defined(_M_AMD64))
+
 /*
  * SSE2 instructions are part of the spec for the 64-bit x86 ISA. We assume
  * that compilers targeting this architecture understand SSE2 intrinsics.
@@ -107,7 +118,9 @@ static inline Vector32 vector32_eq(const Vector32 v1, const Vector32 v2);
 static inline void
 vector8_load(Vector8 *v, const uint8 *s)
 {
-#if defined(USE_SSE2)
+#if defined(USE_AVX2)
+	*v = _mm256_loadu_si256((const __m256i *) s);
+#elif defined(USE_SSE2)
 	*v = _mm_loadu_si128((const __m128i *) s);
 #elif defined(USE_NEON)
 	*v = vld1q_u8(s);
@@ -120,7 +133,9 @@ vector8_load(Vector8 *v, const uint8 *s)
 static inline void
 vector32_load(Vector32 *v, const uint32 *s)
 {
-#ifdef USE_SSE2
+#if defined(USE_AVX2)
+	*v = _mm256_loadu_si256((const __m256i *) s);
+#elif defined(USE_SSE2)
 	*v = _mm_loadu_si128((const __m128i *) s);
 #elif defined(USE_NEON)
 	*v = vld1q_u32(s);
@@ -134,7 +149,9 @@ vector32_load(Vector32 *v, const uint32 *s)
 static inline Vector8
 vector8_broadcast(const uint8 c)
 {
-#if defined(USE_SSE2)
+#if defined(USE_AVX2)
+	return _mm256_set1_epi8(c);
+#elif defined(USE_SSE2)
 	return _mm_set1_epi8(c);
 #elif defined(USE_NEON)
 	return vdupq_n_u8(c);
@@ -147,7 +164,9 @@ vector8_broadcast(const uint8 c)
 static inline Vector32
 vector32_broadcast(const uint32 c)
 {
-#ifdef USE_SSE2
+#if defined(USE_AVX2)
+	return _mm256_set1_epi32(c);
+#elif defined(USE_SSE2)
 	return _mm_set1_epi32(c);
 #elif defined(USE_NEON)
 	return vdupq_n_u32(c);
@@ -270,7 +289,9 @@ vector8_has_le(const Vector8 v, const uint8 c)
 static inline bool
 vector8_is_highbit_set(const Vector8 v)
 {
-#ifdef USE_SSE2
+#if defined(USE_AVX2)
+	return _mm256_movemask_epi8(v) != 0;
+#elif defined(USE_SSE2)
 	return _mm_movemask_epi8(v) != 0;
 #elif defined(USE_NEON)
 	return vmaxvq_u8(v) > 0x7F;
@@ -308,7 +329,9 @@ vector32_is_highbit_set(const Vector32 v)
 static inline uint32
 vector8_highbit_mask(const Vector8 v)
 {
-#ifdef USE_SSE2
+#if defined(USE_AVX2)
+	return (uint32) _mm256_movemask_epi8(v);
+#elif defined(USE_SSE2)
 	return (uint32) _mm_movemask_epi8(v);
 #elif defined(USE_NEON)
 	/*
@@ -337,7 +360,9 @@ vector8_highbit_mask(const Vector8 v)
 static inline Vector8
 vector8_or(const Vector8 v1, const Vector8 v2)
 {
-#ifdef USE_SSE2
+#if defined(USE_AVX2)
+	return _mm256_or_si256(v1, v2);
+#elif defined(USE_SSE2)
 	return _mm_or_si128(v1, v2);
 #elif defined(USE_NEON)
 	return vorrq_u8(v1, v2);
@@ -350,7 +375,9 @@ vector8_or(const Vector8 v1, const Vector8 v2)
 static inline Vector32
 vector32_or(const Vector32 v1, const Vector32 v2)
 {
-#ifdef USE_SSE2
+#if defined(USE_AVX2)
+	return _mm256_or_si256(v1, v2);
+#elif defined(USE_SSE2)
 	return _mm_or_si128(v1, v2);
 #elif defined(USE_NEON)
 	return vorrq_u32(v1, v2);
@@ -368,7 +395,9 @@ vector32_or(const Vector32 v1, const Vector32 v2)
 static inline Vector8
 vector8_ssub(const Vector8 v1, const Vector8 v2)
 {
-#ifdef USE_SSE2
+#if defined(USE_AVX2)
+	return _mm256_subs_epu8(v1, v2);
+#elif defined(USE_SSE2)
 	return _mm_subs_epu8(v1, v2);
 #elif defined(USE_NEON)
 	return vqsubq_u8(v1, v2);
@@ -384,7 +413,9 @@ vector8_ssub(const Vector8 v1, const Vector8 v2)
 static inline Vector8
 vector8_eq(const Vector8 v1, const Vector8 v2)
 {
-#ifdef USE_SSE2
+#if defined(USE_AVX2)
+	return _mm256_cmpeq_epi8(v1, v2);
+#elif defined(USE_SSE2)
 	return _mm_cmpeq_epi8(v1, v2);
 #elif defined(USE_NEON)
 	return vceqq_u8(v1, v2);
@@ -396,7 +427,9 @@ vector8_eq(const Vector8 v1, const Vector8 v2)
 static inline Vector32
 vector32_eq(const Vector32 v1, const Vector32 v2)
 {
-#ifdef USE_SSE2
+#if defined(USE_AVX2)
+	return _mm256_cmpeq_epi32(v1, v2);
+#elif defined(USE_SSE2)
 	return _mm_cmpeq_epi32(v1, v2);
 #elif defined(USE_NEON)
 	return vceqq_u32(v1, v2);
@@ -411,7 +444,9 @@ vector32_eq(const Vector32 v1, const Vector32 v2)
 static inline Vector8
 vector8_min(const Vector8 v1, const Vector8 v2)
 {
-#ifdef USE_SSE2
+#if defined(USE_AVX2)
+	return _mm256_min_epu8(v1, v2);
+#elif defined(USE_SSE2)
 	return _mm_min_epu8(v1, v2);
 #elif defined(USE_NEON)
 	return vminq_u8(v1, v2);
-- 
2.25.1

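For anyone kicking the tires on 0002: it keys everything off the
compiler-defined __AVX2__ macro, so the 256-bit paths are only compiled
in when the build itself targets AVX2 (e.g. gcc/clang -mavx2, or MSVC
/arch:AVX2); there is no runtime dispatch in this version.  The
per-register pattern it relies on is just broadcast/load/compare/
movemask, as in this standalone snippet (illustration only, not code
from the patch; build with "cc -O2 -mavx2 demo.c"):

#include <immintrin.h>
#include <stdbool.h>
#include <stdint.h>
#include <stdio.h>

/* true if any of the 8 uint32 lanes at p equals key */
static bool
any_eq_avx2(uint32_t key, const uint32_t *p)
{
    __m256i keys = _mm256_set1_epi32((int) key);            /* broadcast */
    __m256i vals = _mm256_loadu_si256((const __m256i *) p); /* unaligned load */
    __m256i cmp = _mm256_cmpeq_epi32(keys, vals);           /* lanewise == */

    return _mm256_movemask_epi8(cmp) != 0;                  /* any lane hit? */
}

int
main(void)
{
    uint32_t arr[8] = {1, 2, 3, 4, 5, 6, 7, 42};

    printf("%d %d\n", any_eq_avx2(42, arr), any_eq_avx2(99, arr)); /* 1 0 */
    return 0;
}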