This is an automated email from the ASF dual-hosted git repository.
zanmato pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/arrow.git
The following commit(s) were added to refs/heads/main by this push:
new 32fcd184da GH-44513: [C++] Fix overflow issues for large build side in swiss join (#45108)
32fcd184da is described below
commit 32fcd184da91e5d9bc9098baeef4f368632fc1f1
Author: Rossi Sun <[email protected]>
AuthorDate: Mon Jan 13 22:29:50 2025 +0800
GH-44513: [C++] Fix overflow issues for large build side in swiss join (#45108)
### Rationale for this change
#44513 triggers two distinct overflow issues within swiss join, both happening when the build side table contains a large enough number of rows or distinct keys (hash join build sides of this size are rather rare, so we haven't seen them reported until now):
1. The first issue is that our swiss table implementation takes the high `N` bits of the 32-bit hash value as the index into a buffer storing "block"s (a block contains `8` key ids, also referred to as "group" ids in some code). This `N`-bit number is further multiplied by the size of a block, which is also related to `N`. The `N` in the case of #44513 is `26` and a block takes `40` bytes, so the multiplication can produce a number over `1 << 31` (negative when interpreted as a signed 32-bit integer); see the first sketch after this list.
2. The other issue is that we take the `7` bits of the 32-bit hash value after the `N` block-id bits as a "stamp" (used to quickly fail hash comparisons). But when `N` is greater than `25`, arithmetic code like
https://github.com/apache/arrow/blob/0a00e25f2f6fb927fb555b69038d0be9b9d9f265/cpp/src/arrow/compute/key_map_internal.cc#L397
(`bits_hash_` is `constexpr 32`, `log_blocks_` is `N`, `bits_stamp_` is `constexpr 7`; this retrieves the stamp from a hash) produces `hash >> -1`, aka `hash >> 0xFFFFFFFF`, which is undefined behavior; see the second sketch after this list.
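To make issue 1 concrete, here is a minimal standalone sketch. It is not the actual swiss table code; the names (`kBitsHash`, `kLogBlocks`, `kBlockBytes`) are illustrative, with `N = 26` and the 40-byte block size taken from the description above:

```cpp
#include <cstdint>
#include <iostream>

int main() {
  constexpr int kBitsHash = 32;         // bits in the hash value
  constexpr int kLogBlocks = 26;        // N: the block-id bit count in #44513
  constexpr uint32_t kBlockBytes = 40;  // bytes per block in that case

  uint32_t hash = 0xFFFFFFFFu;                           // worst-case hash
  uint32_t block_id = hash >> (kBitsHash - kLogBlocks);  // top N bits: 2^26 - 1

  // 32-bit product wraps past 1 << 31: negative when read back as signed.
  int32_t offset32 = static_cast<int32_t>(block_id * kBlockBytes);
  // Widening the index to 64 bits before multiplying avoids the overflow,
  // which is what switching to a 64-bit index gather achieves in the fix.
  int64_t offset64 = static_cast<int64_t>(block_id) * kBlockBytes;

  std::cout << offset32 << " vs " << offset64 << std::endl;
  // Prints -1610612776 vs 2684354520 for this hash.
}
```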
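And a matching sketch for issue 2 (again illustrative, with the class constants inlined): once `N` exceeds `25`, the computed shift amount goes negative, and shifting by a negative amount is undefined behavior in C++:

```cpp
#include <cstdint>
#include <iostream>

int main() {
  constexpr int kBitsHash = 32;  // bits_hash_
  constexpr int kBitsStamp = 7;  // bits_stamp_
  int log_blocks = 26;           // N from #44513

  int shift = kBitsHash - log_blocks - kBitsStamp;  // 32 - 26 - 7 == -1
  std::cout << "shift amount: " << shift << std::endl;

  uint32_t hash = 0x12345678u;
  // uint32_t stamp = (hash >> shift) & 0x7f;  // UB: shift by a negative amount
  // The fix clamps the shift to 0 instead, accepting block id/stamp overlap:
  uint32_t stamp = (hash >> (shift < 0 ? 0 : shift)) & 0x7f;
  std::cout << "stamp: " << stamp << std::endl;  // low 7 bits: 0x78 == 120
}
```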
### What changes are included in this PR?
For issue 1, use a 64-bit index gather intrinsic to avoid the offset overflow.
For issue 2, do not right-shift the hash if `N + 7 >= 32`. This effectively allows the bits of the block id (the `N` bits) and the stamp (the `7` bits) to overlap. Though this may introduce more false-positive hash comparisons (and thus worsen performance), I think it is still more reasonable than failing outright for `N > 25`. I introduce two members, `bits_shift_for_block_and_stamp_` and `bits_shift_for_block_`, which are derived from `log_blocks_` - specifically, set to `0` and `32 - N` respectively when `N + 7 > 32`. The sketch below shows the derived values around the boundary.
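To make the derived values concrete, the following standalone sketch mirrors the logic of `ComputeBitsShiftForBlockAndStamp` and `ComputeBitsShiftForBlock` from this PR (class constants inlined, helper names shortened for illustration) and prints what they yield around the `N = 25` boundary:

```cpp
#include <cstdio>

constexpr int kBitsHash = 32;  // bits_hash_
constexpr int kBitsStamp = 7;  // bits_stamp_

// First shift: brings the block id and stamp down to the low bits of the hash.
constexpr int BitsShiftForBlockAndStamp(int log_blocks) {
  return log_blocks + kBitsStamp > kBitsHash ? 0 : kBitsHash - log_blocks - kBitsStamp;
}

// Second shift: drops the stamp bits, leaving only the block id.
constexpr int BitsShiftForBlock(int log_blocks) {
  return log_blocks + kBitsStamp > kBitsHash ? kBitsHash - log_blocks : kBitsStamp;
}

int main() {
  for (int n : {20, 25, 26, 27}) {
    std::printf("N=%2d: block_and_stamp=%d block=%d\n", n,
                BitsShiftForBlockAndStamp(n), BitsShiftForBlock(n));
  }
  // N=20: block_and_stamp=5 block=7  (no overlap)
  // N=25: block_and_stamp=0 block=7  (block id and stamp exactly fill 32 bits)
  // N=26: block_and_stamp=0 block=6  (stamp overlaps the block id by 1 bit)
  // N=27: block_and_stamp=0 block=5  (overlap of 2 bits)
}
```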
### Are these changes tested?
The fix is manually tested with the original case on my local machine. (I do have a concrete C++ UT to verify the fix, but it requires too many resources and runs for too long, making it impractical to run in any reasonable CI environment.)
### Are there any user-facing changes?
None.
* GitHub Issue: #44513
Lead-authored-by: Rossi Sun <[email protected]>
Co-authored-by: Antoine Pitrou <[email protected]>
Signed-off-by: Rossi Sun <[email protected]>
---
cpp/src/arrow/compute/key_map_internal.cc | 21 ++++++----
cpp/src/arrow/compute/key_map_internal.h | 25 +++++++++++-
cpp/src/arrow/compute/key_map_internal_avx2.cc | 55 +++++++++++++++-----------
3 files changed, 69 insertions(+), 32 deletions(-)
diff --git a/cpp/src/arrow/compute/key_map_internal.cc b/cpp/src/arrow/compute/key_map_internal.cc
index f134c91455..ad264533bf 100644
--- a/cpp/src/arrow/compute/key_map_internal.cc
+++ b/cpp/src/arrow/compute/key_map_internal.cc
@@ -254,9 +254,9 @@ void SwissTable::early_filter_imp(const int num_keys, const uint32_t* hashes,
// Extract from hash: block index and stamp
//
uint32_t hash = hashes[i];
- uint32_t iblock = hash >> (bits_hash_ - bits_stamp_ - log_blocks_);
+ uint32_t iblock = hash >> bits_shift_for_block_and_stamp_;
uint32_t stamp = iblock & stamp_mask;
- iblock >>= bits_stamp_;
+ iblock >>= bits_shift_for_block_;
uint32_t num_block_bytes = num_groupid_bits + 8;
const uint8_t* blockbase =
@@ -399,7 +399,7 @@ bool SwissTable::find_next_stamp_match(const uint32_t hash, const uint32_t in_sl
const uint64_t num_groupid_bits = num_groupid_bits_from_log_blocks(log_blocks_);
constexpr uint64_t stamp_mask = 0x7f;
const int stamp =
- static_cast<int>((hash >> (bits_hash_ - log_blocks_ - bits_stamp_)) & stamp_mask);
+ static_cast<int>((hash >> bits_shift_for_block_and_stamp_) & stamp_mask);
uint64_t start_slot_id = wrap_global_slot_id(in_slot_id);
int match_found;
int local_slot;
@@ -659,6 +659,9 @@ Status SwissTable::grow_double() {
int num_group_id_bits_after = num_groupid_bits_from_log_blocks(log_blocks_ + 1);
uint64_t group_id_mask_before = ~0ULL >> (64 - num_group_id_bits_before);
int log_blocks_after = log_blocks_ + 1;
+ int bits_shift_for_block_and_stamp_after =
+ ComputeBitsShiftForBlockAndStamp(log_blocks_after);
+ int bits_shift_for_block_after = ComputeBitsShiftForBlock(log_blocks_after);
uint64_t block_size_before = (8 + num_group_id_bits_before);
uint64_t block_size_after = (8 + num_group_id_bits_after);
uint64_t block_size_total_after = (block_size_after << log_blocks_after) + padding_;
@@ -701,8 +704,7 @@ Status SwissTable::grow_double() {
}
int ihalf = block_id_new & 1;
- uint8_t stamp_new =
- hash >> ((bits_hash_ - log_blocks_after - bits_stamp_)) & stamp_mask;
+ uint8_t stamp_new = (hash >> bits_shift_for_block_and_stamp_after) & stamp_mask;
uint64_t group_id_bit_offs = j * num_group_id_bits_before;
uint64_t group_id =
(util::SafeLoadAs<uint64_t>(block_base + 8 + (group_id_bit_offs >> 3)) >>
@@ -744,8 +746,7 @@ Status SwissTable::grow_double() {
(util::SafeLoadAs<uint64_t>(block_base + 8 + (group_id_bit_offs >> 3)) >>
(group_id_bit_offs & 7)) &
group_id_mask_before;
- uint8_t stamp_new =
- hash >> ((bits_hash_ - log_blocks_after - bits_stamp_)) & stamp_mask;
+ uint8_t stamp_new = (hash >> bits_shift_for_block_and_stamp_after) & stamp_mask;
uint8_t* block_base_new =
blocks_new->mutable_data() + block_id_new * block_size_after;
@@ -773,6 +774,8 @@ Status SwissTable::grow_double() {
blocks_ = std::move(blocks_new);
hashes_ = std::move(hashes_new_buffer);
log_blocks_ = log_blocks_after;
+ bits_shift_for_block_and_stamp_ = bits_shift_for_block_and_stamp_after;
+ bits_shift_for_block_ = bits_shift_for_block_after;
return Status::OK();
}
@@ -784,6 +787,8 @@ Status SwissTable::init(int64_t hardware_flags, MemoryPool* pool, int log_blocks
log_minibatch_ = util::MiniBatch::kLogMiniBatchLength;
log_blocks_ = log_blocks;
+ bits_shift_for_block_and_stamp_ = ComputeBitsShiftForBlockAndStamp(log_blocks_);
+ bits_shift_for_block_ = ComputeBitsShiftForBlock(log_blocks_);
int num_groupid_bits = num_groupid_bits_from_log_blocks(log_blocks_);
num_inserted_ = 0;
@@ -820,6 +825,8 @@ void SwissTable::cleanup() {
hashes_ = nullptr;
}
log_blocks_ = 0;
+ bits_shift_for_block_and_stamp_ = ComputeBitsShiftForBlockAndStamp(log_blocks_);
+ bits_shift_for_block_ = ComputeBitsShiftForBlock(log_blocks_);
num_inserted_ = 0;
}
diff --git a/cpp/src/arrow/compute/key_map_internal.h b/cpp/src/arrow/compute/key_map_internal.h
index a5e784a9e4..66a9957006 100644
--- a/cpp/src/arrow/compute/key_map_internal.h
+++ b/cpp/src/arrow/compute/key_map_internal.h
@@ -203,6 +203,23 @@ class ARROW_EXPORT SwissTable {
// Resize large hash tables when 75% full.
Status grow_double();
+ // When log_blocks is greater than 25, there will be overlapping bits between block id
+ // and stamp within a 32-bit hash value. So we must check if this is the case when
+ // right shifting a hash value to retrieve block id and stamp. The following two
+ // functions derive the number of bits to right shift from the given log_blocks.
+ static int ComputeBitsShiftForBlockAndStamp(int log_blocks) {
+ if (ARROW_PREDICT_FALSE(log_blocks + bits_stamp_ > bits_hash_)) {
+ return 0;
+ }
+ return bits_hash_ - log_blocks - bits_stamp_;
+ }
+ static int ComputeBitsShiftForBlock(int log_blocks) {
+ if (ARROW_PREDICT_FALSE(log_blocks + bits_stamp_ > bits_hash_)) {
+ return bits_hash_ - log_blocks;
+ }
+ return bits_stamp_;
+ }
+
// Number of hash bits stored in slots in a block.
// The highest bits of hash determine block id.
// The next set of highest bits is a "stamp" stored in a slot in a block.
@@ -214,6 +231,11 @@ class ARROW_EXPORT SwissTable {
int log_minibatch_;
// Base 2 log of the number of blocks
int log_blocks_ = 0;
+ // The following two variables are derived from log_blocks_ as log_blocks_ changes, and
+ // used in tight loops to avoid calling the ComputeXXX functions (introducing a
+ // branching on whether log_blocks_ + bits_stamp_ > bits_hash_).
+ int bits_shift_for_block_and_stamp_ = ComputeBitsShiftForBlockAndStamp(log_blocks_);
+ int bits_shift_for_block_ = ComputeBitsShiftForBlock(log_blocks_);
// Number of keys inserted into hash table
uint32_t num_inserted_ = 0;
@@ -271,8 +293,7 @@ void SwissTable::insert_into_empty_slot(uint32_t slot_id, uint32_t hash,
constexpr uint64_t stamp_mask = 0x7f;
int start_slot = (slot_id & 7);
- int stamp =
- static_cast<int>((hash >> (bits_hash_ - log_blocks_ - bits_stamp_)) & stamp_mask);
+ int stamp = static_cast<int>((hash >> bits_shift_for_block_and_stamp_) & stamp_mask);
uint64_t block_id = slot_id >> 3;
uint8_t* blockbase = blocks_->mutable_data() + num_block_bytes * block_id;
diff --git a/cpp/src/arrow/compute/key_map_internal_avx2.cc b/cpp/src/arrow/compute/key_map_internal_avx2.cc
index 1a16603a0f..be54f7de63 100644
--- a/cpp/src/arrow/compute/key_map_internal_avx2.cc
+++ b/cpp/src/arrow/compute/key_map_internal_avx2.cc
@@ -45,10 +45,9 @@ int SwissTable::early_filter_imp_avx2_x8(const int num_hashes, const uint32_t* h
// Calculate block index and hash stamp for a byte in a block
//
__m256i vhash = _mm256_loadu_si256(vhash_ptr + i);
- __m256i vblock_id = _mm256_srlv_epi32(
- vhash, _mm256_set1_epi32(bits_hash_ - bits_stamp_ - log_blocks_));
+ __m256i vblock_id = _mm256_srli_epi32(vhash, bits_shift_for_block_and_stamp_);
__m256i vstamp = _mm256_and_si256(vblock_id, vstamp_mask);
- vblock_id = _mm256_srli_epi32(vblock_id, bits_stamp_);
+ vblock_id = _mm256_srli_epi32(vblock_id, bits_shift_for_block_);
// We now split inputs and process 4 at a time,
// in order to process 64-bit blocks
@@ -301,19 +300,15 @@ int SwissTable::early_filter_imp_avx2_x32(const int num_hashes, const uint32_t*
_mm256_and_si256(vhash2, _mm256_set1_epi32(0xffff0000)));
vhash1 = _mm256_or_si256(_mm256_srli_epi32(vhash1, 16),
_mm256_and_si256(vhash3, _mm256_set1_epi32(0xffff0000)));
- __m256i vstamp_A = _mm256_and_si256(
- _mm256_srlv_epi32(vhash0, _mm256_set1_epi32(16 - log_blocks_ - 7)),
- _mm256_set1_epi16(0x7f));
- __m256i vstamp_B = _mm256_and_si256(
- _mm256_srlv_epi32(vhash1, _mm256_set1_epi32(16 - log_blocks_ - 7)),
- _mm256_set1_epi16(0x7f));
+ __m256i vstamp_A = _mm256_and_si256(_mm256_srli_epi32(vhash0, 16 - log_blocks_ - 7),
+ _mm256_set1_epi16(0x7f));
+ __m256i vstamp_B = _mm256_and_si256(_mm256_srli_epi32(vhash1, 16 - log_blocks_ - 7),
+ _mm256_set1_epi16(0x7f));
__m256i vstamp = _mm256_or_si256(vstamp_A, _mm256_slli_epi16(vstamp_B, 8));
- __m256i vblock_id_A =
- _mm256_and_si256(_mm256_srlv_epi32(vhash0, _mm256_set1_epi32(16 - log_blocks_)),
- _mm256_set1_epi16(block_id_mask));
- __m256i vblock_id_B =
- _mm256_and_si256(_mm256_srlv_epi32(vhash1, _mm256_set1_epi32(16 - log_blocks_)),
- _mm256_set1_epi16(block_id_mask));
+ __m256i vblock_id_A = _mm256_and_si256(_mm256_srli_epi32(vhash0, 16 - log_blocks_),
+ _mm256_set1_epi16(block_id_mask));
+ __m256i vblock_id_B = _mm256_and_si256(_mm256_srli_epi32(vhash1, 16 - log_blocks_),
+ _mm256_set1_epi16(block_id_mask));
__m256i vblock_id = _mm256_or_si256(vblock_id_A, _mm256_slli_epi16(vblock_id_B, 8));
// Visit all block bytes in reverse order (overwriting data on multiple matches)
@@ -392,16 +387,30 @@ int SwissTable::extract_group_ids_avx2(const int num_keys, const uint32_t* hashe
} else {
for (int i = 0; i < num_keys / unroll; ++i) {
__m256i hash = _mm256_loadu_si256(reinterpret_cast<const __m256i*>(hashes) + i);
+ // Extend hash and local_slot to 64-bit to compute 64-bit group id offsets to
+ // gather from. This is to prevent index overflow issues in GH-44513.
+ // NB: Use zero-extend conversion for unsigned hash.
+ __m256i hash_lo = _mm256_cvtepu32_epi64(_mm256_castsi256_si128(hash));
+ __m256i hash_hi = _mm256_cvtepu32_epi64(_mm256_extracti128_si256(hash, 1));
__m256i local_slot =
_mm256_set1_epi64x(reinterpret_cast<const uint64_t*>(local_slots)[i]);
- local_slot = _mm256_shuffle_epi8(
- local_slot, _mm256_setr_epi32(0x80808000, 0x80808001, 0x80808002, 0x80808003,
- 0x80808004, 0x80808005, 0x80808006, 0x80808007));
- local_slot = _mm256_mullo_epi32(local_slot, _mm256_set1_epi32(byte_size));
- __m256i pos = _mm256_srlv_epi32(hash, _mm256_set1_epi32(bits_hash_ - log_blocks_));
- pos = _mm256_mullo_epi32(pos, _mm256_set1_epi32(byte_multiplier));
- pos = _mm256_add_epi32(pos, local_slot);
- __m256i group_id = _mm256_i32gather_epi32(elements, pos, 1);
+ __m256i local_slot_lo = _mm256_shuffle_epi8(
+ local_slot, _mm256_setr_epi32(0x80808000, 0x80808080, 0x80808001, 0x80808080,
+ 0x80808002, 0x80808080, 0x80808003, 0x80808080));
+ __m256i local_slot_hi = _mm256_shuffle_epi8(
+ local_slot, _mm256_setr_epi32(0x80808004, 0x80808080, 0x80808005, 0x80808080,
+ 0x80808006, 0x80808080, 0x80808007, 0x80808080));
+ local_slot_lo = _mm256_mul_epu32(local_slot_lo, _mm256_set1_epi32(byte_size));
+ local_slot_hi = _mm256_mul_epu32(local_slot_hi, _mm256_set1_epi32(byte_size));
+ __m256i pos_lo = _mm256_srli_epi64(hash_lo, bits_hash_ - log_blocks_);
+ __m256i pos_hi = _mm256_srli_epi64(hash_hi, bits_hash_ - log_blocks_);
+ pos_lo = _mm256_mul_epu32(pos_lo, _mm256_set1_epi32(byte_multiplier));
+ pos_hi = _mm256_mul_epu32(pos_hi, _mm256_set1_epi32(byte_multiplier));
+ pos_lo = _mm256_add_epi64(pos_lo, local_slot_lo);
+ pos_hi = _mm256_add_epi64(pos_hi, local_slot_hi);
+ __m128i group_id_lo = _mm256_i64gather_epi32(elements, pos_lo, 1);
+ __m128i group_id_hi = _mm256_i64gather_epi32(elements, pos_hi, 1);
+ __m256i group_id = _mm256_set_m128i(group_id_hi, group_id_lo);
group_id = _mm256_and_si256(group_id, _mm256_set1_epi32(mask));
_mm256_storeu_si256(reinterpret_cast<__m256i*>(out_group_ids) + i, group_id);
}