This is an automated email from the ASF dual-hosted git repository.
zanmato pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/arrow.git
The following commit(s) were added to refs/heads/main by this push:
new b818560529 GH-45334: [C++][Acero] Fix swiss join overflow issues in
row offset calculation for fixed length and null masks (#45336)
b818560529 is described below
commit b8185605295e55b1dc8740684351403f1860d87f
Author: Rossi Sun <[email protected]>
AuthorDate: Tue Jan 28 00:25:29 2025 +0800
GH-45334: [C++][Acero] Fix swiss join overflow issues in row offset
calculation for fixed length and null masks (#45336)
### Rationale for this change
#45334
### What changes are included in this PR?
1. A comprehensive test case that effectively reveals all the bugs
mentioned in the issue;
2. Besides directly fixing the bugs (simply casting to 64-bit somewhere in the
multiplication would suffice), I refined the buffer accessors of the row table to
eliminate further potential issues of the same kind (which I believe do exist):
1. `null_masks()` -> `null_masks(row_id)` which does overflow-safe
indexing inside;
2. `is_null(row_id, col_pos)` which does overflow-safe indexing and
directly gets the bit of the column;
3. `data(1)` -> `fixed_length_rows(row_id)`, which first asserts that the row
table is fixed-length, then does overflow-safe indexing inside;
4. `data(2)` -> `var_length_rows()`, which only asserts that the row table
is var-length. It is meant to be paired with `offsets()` (which is
already 64-bit since #43389);
5. The `data(0/1/2)` members are made private.
3. The AVX2 specializations are fixed individually by using 64-bit
multiplication and indexing.
### Are these changes tested?
Yes.
### Are there any user-facing changes?
None.
* GitHub Issue: #45334
Authored-by: Rossi Sun <[email protected]>
Signed-off-by: Rossi Sun <[email protected]>
---
cpp/src/arrow/acero/hash_join_node_test.cc | 99 ++++++++++++++++++++++
cpp/src/arrow/acero/swiss_join.cc | 36 +++++---
cpp/src/arrow/acero/swiss_join_avx2.cc | 34 ++------
cpp/src/arrow/acero/swiss_join_internal.h | 14 ++-
cpp/src/arrow/compute/row/compare_internal.cc | 16 ++--
cpp/src/arrow/compute/row/compare_internal_avx2.cc | 81 ++++--------------
cpp/src/arrow/compute/row/compare_test.cc | 6 +-
cpp/src/arrow/compute/row/encode_internal.cc | 54 +++++-------
cpp/src/arrow/compute/row/encode_internal.h | 10 +--
cpp/src/arrow/compute/row/encode_internal_avx2.cc | 15 ++--
cpp/src/arrow/compute/row/row_internal.cc | 10 ++-
cpp/src/arrow/compute/row/row_internal.h | 61 +++++++++----
cpp/src/arrow/compute/row/row_test.cc | 11 ++-
cpp/src/arrow/compute/row/row_util_avx2_internal.h | 64 ++++++++++++++
14 files changed, 313 insertions(+), 198 deletions(-)
diff --git a/cpp/src/arrow/acero/hash_join_node_test.cc
b/cpp/src/arrow/acero/hash_join_node_test.cc
index 94504ccc9b..654fd59c45 100644
--- a/cpp/src/arrow/acero/hash_join_node_test.cc
+++ b/cpp/src/arrow/acero/hash_join_node_test.cc
@@ -3449,5 +3449,104 @@ TEST(HashJoin,
LARGE_MEMORY_TEST(BuildSideOver4GBVarLength)) {
num_batches_left * num_rows_per_batch_left *
num_batches_right);
}
+// GH-45334: The row ids of the matching rows on the right side (the build
side) are very
+// big, causing the index calculation overflow.
+TEST(HashJoin, BuildSideLargeRowIds) {
+ GTEST_SKIP() << "Test disabled due to excessively time and resource
consuming, "
+ "for local debugging only.";
+
+ // A fair amount of match rows to trigger both SIMD and non-SIMD code paths.
+ const int64_t num_match_rows = 35;
+ const int64_t num_rows_per_match_batch = 35;
+ const int64_t num_match_batches = num_match_rows / num_rows_per_match_batch;
+
+ const int64_t num_unmatch_rows_large = 720898048;
+ const int64_t num_rows_per_unmatch_batch_large = 352001;
+ const int64_t num_unmatch_batches_large =
+ num_unmatch_rows_large / num_rows_per_unmatch_batch_large;
+
+ auto schema_small =
+ schema({field("small_key", int64()), field("small_payload", int64())});
+ auto schema_large =
+ schema({field("large_key", int64()), field("large_payload", int64())});
+
+ // A carefully chosen key value which hashes to 0xFFFFFFFE, making the match
rows to be
+ // placed at higher address of the row table.
+ const int64_t match_key = 289339070;
+ const int64_t match_payload = 42;
+
+ // Match arrays of length num_rows_per_match_batch.
+ ASSERT_OK_AND_ASSIGN(
+ auto match_key_arr,
+ Constant(MakeScalar(match_key))->Generate(num_rows_per_match_batch));
+ ASSERT_OK_AND_ASSIGN(
+ auto match_payload_arr,
+ Constant(MakeScalar(match_payload))->Generate(num_rows_per_match_batch));
+ // Append 1 row of null to trigger null processing code paths.
+ ASSERT_OK_AND_ASSIGN(auto null_arr, MakeArrayOfNull(int64(), 1));
+ ASSERT_OK_AND_ASSIGN(match_key_arr, Concatenate({match_key_arr, null_arr}));
+ ASSERT_OK_AND_ASSIGN(match_payload_arr, Concatenate({match_payload_arr,
null_arr}));
+ // Match batch.
+ ExecBatch match_batch({match_key_arr, match_payload_arr},
num_rows_per_match_batch + 1);
+
+ // Small batch.
+ ExecBatch batch_small = match_batch;
+
+ // Large unmatch batches.
+ const int64_t seed = 42;
+ std::vector<ExecBatch> unmatch_batches_large;
+ unmatch_batches_large.reserve(num_unmatch_batches_large);
+ ASSERT_OK_AND_ASSIGN(auto unmatch_payload_arr_large,
+ MakeArrayOfNull(int64(),
num_rows_per_unmatch_batch_large));
+ int64_t unmatch_range_per_batch =
+ (std::numeric_limits<int64_t>::max() - match_key) /
num_unmatch_batches_large;
+ for (int i = 0; i < num_unmatch_batches_large; ++i) {
+ auto unmatch_key_arr_large = RandomArrayGenerator(seed).Int64(
+ num_rows_per_unmatch_batch_large,
+ /*min=*/match_key + 1 + i * unmatch_range_per_batch,
+ /*max=*/match_key + 1 + (i + 1) * unmatch_range_per_batch);
+ unmatch_batches_large.push_back(
+ ExecBatch({unmatch_key_arr_large, unmatch_payload_arr_large},
+ num_rows_per_unmatch_batch_large));
+ }
+ // Large match batch.
+ ExecBatch match_batch_large = match_batch;
+
+ // Batches with schemas.
+ auto batches_small = BatchesWithSchema{
+ std::vector<ExecBatch>(num_match_batches, batch_small), schema_small};
+ auto batches_large = BatchesWithSchema{std::move(unmatch_batches_large),
schema_large};
+ for (int i = 0; i < num_match_batches; i++) {
+ batches_large.batches.push_back(match_batch_large);
+ }
+
+ Declaration source_small{
+ "exec_batch_source",
+ ExecBatchSourceNodeOptions(batches_small.schema, batches_small.batches)};
+ Declaration source_large{
+ "exec_batch_source",
+ ExecBatchSourceNodeOptions(batches_large.schema, batches_large.batches)};
+
+ HashJoinNodeOptions join_opts(JoinType::INNER, /*left_keys=*/{"small_key"},
+ /*right_keys=*/{"large_key"});
+ Declaration join{
+ "hashjoin", {std::move(source_small), std::move(source_large)},
join_opts};
+
+ // Join should emit num_match_rows * num_match_rows rows.
+ ASSERT_OK_AND_ASSIGN(auto batches_result,
DeclarationToExecBatches(std::move(join)));
+ Declaration result{"exec_batch_source",
+
ExecBatchSourceNodeOptions(std::move(batches_result.schema),
+
std::move(batches_result.batches))};
+ AssertRowCountEq(result, num_match_rows * num_match_rows);
+
+ // All rows should be match_key/payload.
+ auto predicate = and_({equal(field_ref("small_key"), literal(match_key)),
+ equal(field_ref("small_payload"),
literal(match_payload)),
+ equal(field_ref("large_key"), literal(match_key)),
+ equal(field_ref("large_payload"),
literal(match_payload))});
+ Declaration filter{"filter", {result},
FilterNodeOptions{std::move(predicate)}};
+ AssertRowCountEq(std::move(filter), num_match_rows * num_match_rows);
+}
+
} // namespace acero
} // namespace arrow
diff --git a/cpp/src/arrow/acero/swiss_join.cc
b/cpp/src/arrow/acero/swiss_join.cc
index fc3be1b462..85e14ac469 100644
--- a/cpp/src/arrow/acero/swiss_join.cc
+++ b/cpp/src/arrow/acero/swiss_join.cc
@@ -477,14 +477,15 @@ void RowArrayMerge::CopyFixedLength(RowTableImpl* target,
const RowTableImpl& so
const int64_t* source_rows_permutation) {
int64_t num_source_rows = source.length();
- int64_t fixed_length = target->metadata().fixed_length;
+ uint32_t fixed_length = target->metadata().fixed_length;
// Permutation of source rows is optional. Without permutation all that is
// needed is memcpy.
//
if (!source_rows_permutation) {
- memcpy(target->mutable_data(1) + fixed_length * first_target_row_id,
source.data(1),
- fixed_length * num_source_rows);
+ DCHECK_LE(first_target_row_id, std::numeric_limits<uint32_t>::max());
+
memcpy(target->mutable_fixed_length_rows(static_cast<uint32_t>(first_target_row_id)),
+ source.fixed_length_rows(/*row_id=*/0), fixed_length *
num_source_rows);
} else {
// Row length must be a multiple of 64-bits due to enforced alignment.
// Loop for each output row copying a fixed number of 64-bit words.
@@ -494,10 +495,13 @@ void RowArrayMerge::CopyFixedLength(RowTableImpl* target,
const RowTableImpl& so
int64_t num_words_per_row = fixed_length / sizeof(uint64_t);
for (int64_t i = 0; i < num_source_rows; ++i) {
int64_t source_row_id = source_rows_permutation[i];
+ DCHECK_LE(source_row_id, std::numeric_limits<uint32_t>::max());
const uint64_t* source_row_ptr = reinterpret_cast<const uint64_t*>(
- source.data(1) + fixed_length * source_row_id);
+ source.fixed_length_rows(static_cast<uint32_t>(source_row_id)));
+ int64_t target_row_id = first_target_row_id + i;
+ DCHECK_LE(target_row_id, std::numeric_limits<uint32_t>::max());
uint64_t* target_row_ptr = reinterpret_cast<uint64_t*>(
- target->mutable_data(1) + fixed_length * (first_target_row_id + i));
+
target->mutable_fixed_length_rows(static_cast<uint32_t>(target_row_id)));
for (int64_t word = 0; word < num_words_per_row; ++word) {
target_row_ptr[word] = source_row_ptr[word];
@@ -529,16 +533,16 @@ void RowArrayMerge::CopyVaryingLength(RowTableImpl*
target, const RowTableImpl&
// We can simply memcpy bytes of rows if their order has not changed.
//
- memcpy(target->mutable_data(2) + target_offsets[first_target_row_id],
source.data(2),
- source_offsets[num_source_rows] - source_offsets[0]);
+ memcpy(target->mutable_var_length_rows() +
target_offsets[first_target_row_id],
+ source.var_length_rows(), source_offsets[num_source_rows] -
source_offsets[0]);
} else {
int64_t target_row_offset = first_target_row_offset;
- uint64_t* target_row_ptr =
- reinterpret_cast<uint64_t*>(target->mutable_data(2) +
target_row_offset);
+ uint64_t* target_row_ptr = reinterpret_cast<uint64_t*>(
+ target->mutable_var_length_rows() + target_row_offset);
for (int64_t i = 0; i < num_source_rows; ++i) {
int64_t source_row_id = source_rows_permutation[i];
const uint64_t* source_row_ptr = reinterpret_cast<const uint64_t*>(
- source.data(2) + source_offsets[source_row_id]);
+ source.var_length_rows() + source_offsets[source_row_id]);
int64_t length = source_offsets[source_row_id + 1] -
source_offsets[source_row_id];
// Though the row offset is 64-bit, the length of a single row must be
32-bit as
// required by current row table implementation.
@@ -564,14 +568,18 @@ void RowArrayMerge::CopyNulls(RowTableImpl* target, const
RowTableImpl& source,
const int64_t* source_rows_permutation) {
int64_t num_source_rows = source.length();
int num_bytes_per_row = target->metadata().null_masks_bytes_per_row;
- uint8_t* target_nulls = target->null_masks() + num_bytes_per_row *
first_target_row_id;
+ DCHECK_LE(first_target_row_id, std::numeric_limits<uint32_t>::max());
+ uint8_t* target_nulls =
+ target->mutable_null_masks(static_cast<uint32_t>(first_target_row_id));
if (!source_rows_permutation) {
- memcpy(target_nulls, source.null_masks(), num_bytes_per_row *
num_source_rows);
+ memcpy(target_nulls, source.null_masks(/*row_id=*/0),
+ num_bytes_per_row * num_source_rows);
} else {
for (int64_t i = 0; i < num_source_rows; ++i) {
int64_t source_row_id = source_rows_permutation[i];
+ DCHECK_LE(source_row_id, std::numeric_limits<uint32_t>::max());
const uint8_t* source_nulls =
- source.null_masks() + num_bytes_per_row * source_row_id;
+ source.null_masks(static_cast<uint32_t>(source_row_id));
for (int64_t byte = 0; byte < num_bytes_per_row; ++byte) {
*target_nulls++ = *source_nulls++;
}
diff --git a/cpp/src/arrow/acero/swiss_join_avx2.cc
b/cpp/src/arrow/acero/swiss_join_avx2.cc
index 1d6b7eda6e..deeee2a4e1 100644
--- a/cpp/src/arrow/acero/swiss_join_avx2.cc
+++ b/cpp/src/arrow/acero/swiss_join_avx2.cc
@@ -16,6 +16,7 @@
// under the License.
#include "arrow/acero/swiss_join_internal.h"
+#include "arrow/compute/row/row_util_avx2_internal.h"
#include "arrow/util/bit_util.h"
#include "arrow/util/simd.h"
@@ -46,7 +47,7 @@ int RowArrayAccessor::Visit_avx2(const RowTableImpl& rows,
int column_id, int nu
if (!is_fixed_length_column) {
int varbinary_column_id = VarbinaryColumnId(rows.metadata(), column_id);
- const uint8_t* row_ptr_base = rows.data(2);
+ const uint8_t* row_ptr_base = rows.var_length_rows();
const RowTableImpl::offset_type* row_offsets = rows.offsets();
auto row_offsets_i64 =
reinterpret_cast<const arrow::util::int64_for_gather_t*>(row_offsets);
@@ -172,7 +173,7 @@ int RowArrayAccessor::Visit_avx2(const RowTableImpl& rows,
int column_id, int nu
if (is_fixed_length_row) {
// Case 3: This is a fixed length column in fixed length row
//
- const uint8_t* row_ptr_base = rows.data(1);
+ const uint8_t* row_ptr_base = rows.fixed_length_rows(/*row_id=*/0);
for (int i = 0; i < num_rows / kUnroll; ++i) {
// Load 8 32-bit row ids.
__m256i row_id =
@@ -197,7 +198,7 @@ int RowArrayAccessor::Visit_avx2(const RowTableImpl& rows,
int column_id, int nu
} else {
// Case 4: This is a fixed length column in varying length row
//
- const uint8_t* row_ptr_base = rows.data(2);
+ const uint8_t* row_ptr_base = rows.var_length_rows();
const RowTableImpl::offset_type* row_offsets = rows.offsets();
auto row_offsets_i64 =
reinterpret_cast<const
arrow::util::int64_for_gather_t*>(row_offsets);
@@ -237,31 +238,12 @@ int RowArrayAccessor::VisitNulls_avx2(const RowTableImpl&
rows, int column_id,
//
constexpr int kUnroll = 8;
- const uint8_t* null_masks = rows.null_masks();
- __m256i null_bits_per_row =
- _mm256_set1_epi32(8 * rows.metadata().null_masks_bytes_per_row);
- __m256i pos_after_encoding =
- _mm256_set1_epi32(rows.metadata().pos_after_encoding(column_id));
+ uint32_t pos_after_encoding = rows.metadata().pos_after_encoding(column_id);
for (int i = 0; i < num_rows / kUnroll; ++i) {
__m256i row_id = _mm256_loadu_si256(reinterpret_cast<const
__m256i*>(row_ids) + i);
- __m256i bit_id = _mm256_mullo_epi32(row_id, null_bits_per_row);
- bit_id = _mm256_add_epi32(bit_id, pos_after_encoding);
- __m256i bytes = _mm256_i32gather_epi32(reinterpret_cast<const
int*>(null_masks),
- _mm256_srli_epi32(bit_id, 3), 1);
- __m256i bit_in_word = _mm256_sllv_epi32(
- _mm256_set1_epi32(1), _mm256_and_si256(bit_id, _mm256_set1_epi32(7)));
- // `result` will contain one 32-bit word per tested null bit, either
0xffffffff if the
- // null bit was set or 0 if it was unset.
- __m256i result =
- _mm256_cmpeq_epi32(_mm256_and_si256(bytes, bit_in_word), bit_in_word);
- // NB: Be careful about sign-extension when casting the return value of
- // _mm256_movemask_epi8 (signed 32-bit) to unsigned 64-bit, which will
pollute the
- // higher bits of the following OR.
- uint32_t null_bytes_lo = static_cast<uint32_t>(
-
_mm256_movemask_epi8(_mm256_cvtepi32_epi64(_mm256_castsi256_si128(result))));
- uint64_t null_bytes_hi =
-
_mm256_movemask_epi8(_mm256_cvtepi32_epi64(_mm256_extracti128_si256(result,
1)));
- uint64_t null_bytes = null_bytes_lo | (null_bytes_hi << 32);
+ __m256i null32 = GetNullBitInt32(rows, pos_after_encoding, row_id);
+ null32 = _mm256_cmpeq_epi32(null32, _mm256_set1_epi32(1));
+ uint64_t null_bytes = arrow::compute::Cmp32To8(null32);
process_8_values_fn(i * kUnroll, null_bytes);
}
diff --git a/cpp/src/arrow/acero/swiss_join_internal.h
b/cpp/src/arrow/acero/swiss_join_internal.h
index f2f3ac5b1b..85f443b032 100644
--- a/cpp/src/arrow/acero/swiss_join_internal.h
+++ b/cpp/src/arrow/acero/swiss_join_internal.h
@@ -72,7 +72,7 @@ class RowArrayAccessor {
if (!is_fixed_length_column) {
int varbinary_column_id = VarbinaryColumnId(rows.metadata(), column_id);
- const uint8_t* row_ptr_base = rows.data(2);
+ const uint8_t* row_ptr_base = rows.var_length_rows();
const RowTableImpl::offset_type* row_offsets = rows.offsets();
uint32_t field_offset_within_row, field_length;
@@ -108,22 +108,21 @@ class RowArrayAccessor {
if (field_length == 0) {
field_length = 1;
}
- uint32_t row_length = rows.metadata().fixed_length;
bool is_fixed_length_row = rows.metadata().is_fixed_length;
if (is_fixed_length_row) {
// Case 3: This is a fixed length column in a fixed length row
//
- const uint8_t* row_ptr_base = rows.data(1) + field_offset_within_row;
for (int i = 0; i < num_rows; ++i) {
uint32_t row_id = row_ids[i];
- const uint8_t* row_ptr = row_ptr_base + row_length * row_id;
+ const uint8_t* row_ptr =
+ rows.fixed_length_rows(row_id) + field_offset_within_row;
process_value_fn(i, row_ptr, field_length);
}
} else {
// Case 4: This is a fixed length column in a varying length row
//
- const uint8_t* row_ptr_base = rows.data(2) + field_offset_within_row;
+ const uint8_t* row_ptr_base = rows.var_length_rows() +
field_offset_within_row;
const RowTableImpl::offset_type* row_offsets = rows.offsets();
for (int i = 0; i < num_rows; ++i) {
uint32_t row_id = row_ids[i];
@@ -142,13 +141,10 @@ class RowArrayAccessor {
template <class PROCESS_VALUE_FN>
static void VisitNulls(const RowTableImpl& rows, int column_id, int num_rows,
const uint32_t* row_ids, PROCESS_VALUE_FN
process_value_fn) {
- const uint8_t* null_masks = rows.null_masks();
- uint32_t null_mask_num_bytes = rows.metadata().null_masks_bytes_per_row;
uint32_t pos_after_encoding =
rows.metadata().pos_after_encoding(column_id);
for (int i = 0; i < num_rows; ++i) {
uint32_t row_id = row_ids[i];
- int64_t bit_id = row_id * null_mask_num_bytes * 8 + pos_after_encoding;
- process_value_fn(i, bit_util::GetBit(null_masks, bit_id) ? 0xff : 0);
+ process_value_fn(i, rows.is_null(row_id, pos_after_encoding) ? 0xff : 0);
}
}
diff --git a/cpp/src/arrow/compute/row/compare_internal.cc
b/cpp/src/arrow/compute/row/compare_internal.cc
index 5e1a87b795..b7a01ea75a 100644
--- a/cpp/src/arrow/compute/row/compare_internal.cc
+++ b/cpp/src/arrow/compute/row/compare_internal.cc
@@ -55,13 +55,10 @@ void KeyCompare::NullUpdateColumnToRow(uint32_t id_col,
uint32_t num_rows_to_com
if (!col.data(0)) {
// Remove rows from the result for which the column value is a null
- const uint8_t* null_masks = rows.null_masks();
- uint32_t null_mask_num_bytes = rows.metadata().null_masks_bytes_per_row;
for (uint32_t i = num_processed; i < num_rows_to_compare; ++i) {
uint32_t irow_left = use_selection ? sel_left_maybe_null[i] : i;
uint32_t irow_right = left_to_right_map[irow_left];
- int64_t bitid = irow_right * null_mask_num_bytes * 8 + null_bit_id;
- match_bytevector[i] &= (bit_util::GetBit(null_masks, bitid) ? 0 : 0xff);
+ match_bytevector[i] &= (rows.is_null(irow_right, null_bit_id) ? 0 :
0xff);
}
} else if (!rows.has_any_nulls(ctx)) {
// Remove rows from the result for which the column value on left side is
@@ -74,15 +71,12 @@ void KeyCompare::NullUpdateColumnToRow(uint32_t id_col,
uint32_t num_rows_to_com
bit_util::GetBit(non_nulls, irow_left + col.bit_offset(0)) ? 0xff :
0;
}
} else {
- const uint8_t* null_masks = rows.null_masks();
- uint32_t null_mask_num_bytes = rows.metadata().null_masks_bytes_per_row;
const uint8_t* non_nulls = col.data(0);
ARROW_DCHECK(non_nulls);
for (uint32_t i = num_processed; i < num_rows_to_compare; ++i) {
uint32_t irow_left = use_selection ? sel_left_maybe_null[i] : i;
uint32_t irow_right = left_to_right_map[irow_left];
- int64_t bitid_right = irow_right * null_mask_num_bytes * 8 + null_bit_id;
- int right_null = bit_util::GetBit(null_masks, bitid_right) ? 0xff : 0;
+ int right_null = rows.is_null(irow_right, null_bit_id) ? 0xff : 0;
int left_null =
bit_util::GetBit(non_nulls, irow_left + col.bit_offset(0)) ? 0 :
0xff;
match_bytevector[i] |= left_null & right_null;
@@ -101,7 +95,7 @@ void KeyCompare::CompareBinaryColumnToRowHelper(
if (is_fixed_length) {
uint32_t fixed_length = rows.metadata().fixed_length;
const uint8_t* rows_left = col.data(1);
- const uint8_t* rows_right = rows.data(1);
+ const uint8_t* rows_right = rows.fixed_length_rows(/*row_id=*/0);
for (uint32_t i = first_row_to_compare; i < num_rows_to_compare; ++i) {
uint32_t irow_left = use_selection ? sel_left_maybe_null[i] : i;
// irow_right is used to index into row data so promote to the row
offset type.
@@ -113,7 +107,7 @@ void KeyCompare::CompareBinaryColumnToRowHelper(
} else {
const uint8_t* rows_left = col.data(1);
const RowTableImpl::offset_type* offsets_right = rows.offsets();
- const uint8_t* rows_right = rows.data(2);
+ const uint8_t* rows_right = rows.var_length_rows();
for (uint32_t i = first_row_to_compare; i < num_rows_to_compare; ++i) {
uint32_t irow_left = use_selection ? sel_left_maybe_null[i] : i;
uint32_t irow_right = left_to_right_map[irow_left];
@@ -246,7 +240,7 @@ void KeyCompare::CompareVarBinaryColumnToRowHelper(
const uint32_t* offsets_left = col.offsets();
const RowTableImpl::offset_type* offsets_right = rows.offsets();
const uint8_t* rows_left = col.data(2);
- const uint8_t* rows_right = rows.data(2);
+ const uint8_t* rows_right = rows.var_length_rows();
for (uint32_t i = first_row_to_compare; i < num_rows_to_compare; ++i) {
uint32_t irow_left = use_selection ? sel_left_maybe_null[i] : i;
uint32_t irow_right = left_to_right_map[irow_left];
diff --git a/cpp/src/arrow/compute/row/compare_internal_avx2.cc
b/cpp/src/arrow/compute/row/compare_internal_avx2.cc
index 9f6e1adfe2..8af84ac6b2 100644
--- a/cpp/src/arrow/compute/row/compare_internal_avx2.cc
+++ b/cpp/src/arrow/compute/row/compare_internal_avx2.cc
@@ -16,6 +16,7 @@
// under the License.
#include "arrow/compute/row/compare_internal.h"
+#include "arrow/compute/row/row_util_avx2_internal.h"
#include "arrow/compute/util.h"
#include "arrow/util/bit_util.h"
#include "arrow/util/simd.h"
@@ -49,9 +50,6 @@ uint32_t KeyCompare::NullUpdateColumnToRowImp_avx2(
if (!col.data(0)) {
// Remove rows from the result for which the column value is a null
- const uint8_t* null_masks = rows.null_masks();
- uint32_t null_mask_num_bytes = rows.metadata().null_masks_bytes_per_row;
-
uint32_t num_processed = 0;
constexpr uint32_t unroll = 8;
for (uint32_t i = 0; i < num_rows_to_compare / unroll; ++i) {
@@ -64,21 +62,9 @@ uint32_t KeyCompare::NullUpdateColumnToRowImp_avx2(
irow_right =
_mm256_loadu_si256(reinterpret_cast<const
__m256i*>(left_to_right_map) + i);
}
- __m256i bitid =
- _mm256_mullo_epi32(irow_right, _mm256_set1_epi32(null_mask_num_bytes
* 8));
- bitid = _mm256_add_epi32(bitid, _mm256_set1_epi32(null_bit_id));
- __m256i right =
- _mm256_i32gather_epi32((const int*)null_masks,
_mm256_srli_epi32(bitid, 3), 1);
- right = _mm256_and_si256(
- _mm256_set1_epi32(1),
- _mm256_srlv_epi32(right, _mm256_and_si256(bitid,
_mm256_set1_epi32(7))));
+ __m256i right = GetNullBitInt32(rows, null_bit_id, irow_right);
__m256i cmp = _mm256_cmpeq_epi32(right, _mm256_setzero_si256());
- uint32_t result_lo =
-
_mm256_movemask_epi8(_mm256_cvtepi32_epi64(_mm256_castsi256_si128(cmp)));
- uint32_t result_hi =
-
_mm256_movemask_epi8(_mm256_cvtepi32_epi64(_mm256_extracti128_si256(cmp, 1)));
- reinterpret_cast<uint64_t*>(match_bytevector)[i] &=
- result_lo | (static_cast<uint64_t>(result_hi) << 32);
+ reinterpret_cast<uint64_t*>(match_bytevector)[i] &= Cmp32To8(cmp);
}
num_processed = num_rows_to_compare / unroll * unroll;
return num_processed;
@@ -107,18 +93,11 @@ uint32_t KeyCompare::NullUpdateColumnToRowImp_avx2(
__m256i bits = _mm256_setr_epi32(1, 2, 4, 8, 16, 32, 64, 128);
cmp = _mm256_cmpeq_epi32(_mm256_and_si256(left, bits), bits);
}
- uint32_t result_lo =
-
_mm256_movemask_epi8(_mm256_cvtepi32_epi64(_mm256_castsi256_si128(cmp)));
- uint32_t result_hi =
-
_mm256_movemask_epi8(_mm256_cvtepi32_epi64(_mm256_extracti128_si256(cmp, 1)));
- reinterpret_cast<uint64_t*>(match_bytevector)[i] &=
- result_lo | (static_cast<uint64_t>(result_hi) << 32);
- num_processed = num_rows_to_compare / unroll * unroll;
+ reinterpret_cast<uint64_t*>(match_bytevector)[i] &= Cmp32To8(cmp);
}
+ num_processed = num_rows_to_compare / unroll * unroll;
return num_processed;
} else {
- const uint8_t* null_masks = rows.null_masks();
- uint32_t null_mask_num_bytes = rows.metadata().null_masks_bytes_per_row;
const uint8_t* non_nulls = col.data(0);
ARROW_DCHECK(non_nulls);
@@ -147,29 +126,11 @@ uint32_t KeyCompare::NullUpdateColumnToRowImp_avx2(
left_null =
_mm256_cmpeq_epi32(_mm256_and_si256(left, bits),
_mm256_setzero_si256());
}
- __m256i bitid =
- _mm256_mullo_epi32(irow_right, _mm256_set1_epi32(null_mask_num_bytes
* 8));
- bitid = _mm256_add_epi32(bitid, _mm256_set1_epi32(null_bit_id));
- __m256i right =
- _mm256_i32gather_epi32((const int*)null_masks,
_mm256_srli_epi32(bitid, 3), 1);
- right = _mm256_and_si256(
- _mm256_set1_epi32(1),
- _mm256_srlv_epi32(right, _mm256_and_si256(bitid,
_mm256_set1_epi32(7))));
+ __m256i right = GetNullBitInt32(rows, null_bit_id, irow_right);
__m256i right_null = _mm256_cmpeq_epi32(right, _mm256_set1_epi32(1));
- uint64_t left_null_64 =
- static_cast<uint32_t>(_mm256_movemask_epi8(
- _mm256_cvtepi32_epi64(_mm256_castsi256_si128(left_null)))) |
- (static_cast<uint64_t>(static_cast<uint32_t>(_mm256_movemask_epi8(
- _mm256_cvtepi32_epi64(_mm256_extracti128_si256(left_null, 1)))))
- << 32);
-
- uint64_t right_null_64 =
- static_cast<uint32_t>(_mm256_movemask_epi8(
- _mm256_cvtepi32_epi64(_mm256_castsi256_si128(right_null)))) |
- (static_cast<uint64_t>(static_cast<uint32_t>(_mm256_movemask_epi8(
- _mm256_cvtepi32_epi64(_mm256_extracti128_si256(right_null,
1)))))
- << 32);
+ uint64_t left_null_64 = Cmp32To8(left_null);
+ uint64_t right_null_64 = Cmp32To8(right_null);
reinterpret_cast<uint64_t*>(match_bytevector)[i] |= left_null_64 &
right_null_64;
reinterpret_cast<uint64_t*>(match_bytevector)[i] &= ~(left_null_64 ^
right_null_64);
@@ -189,7 +150,7 @@ uint32_t KeyCompare::CompareBinaryColumnToRowHelper_avx2(
if (is_fixed_length) {
uint32_t fixed_length = rows.metadata().fixed_length;
const uint8_t* rows_left = col.data(1);
- const uint8_t* rows_right = rows.data(1);
+ const uint8_t* rows_right = rows.fixed_length_rows(/*row_id=*/0);
constexpr uint32_t unroll = 8;
__m256i irow_left = _mm256_setr_epi32(0, 1, 2, 3, 4, 5, 6, 7);
for (uint32_t i = 0; i < num_rows_to_compare / unroll; ++i) {
@@ -234,7 +195,7 @@ uint32_t KeyCompare::CompareBinaryColumnToRowHelper_avx2(
} else {
const uint8_t* rows_left = col.data(1);
const RowTableImpl::offset_type* offsets_right = rows.offsets();
- const uint8_t* rows_right = rows.data(2);
+ const uint8_t* rows_right = rows.var_length_rows();
constexpr uint32_t unroll = 8;
__m256i irow_left = _mm256_setr_epi32(0, 1, 2, 3, 4, 5, 6, 7);
for (uint32_t i = 0; i < num_rows_to_compare / unroll; ++i) {
@@ -321,12 +282,7 @@ inline uint64_t CompareSelected8_avx2(const uint8_t*
left_base, const uint8_t* r
__m256i cmp = _mm256_cmpeq_epi32(left, right);
- uint32_t result_lo =
- _mm256_movemask_epi8(_mm256_cvtepi32_epi64(_mm256_castsi256_si128(cmp)));
- uint32_t result_hi =
- _mm256_movemask_epi8(_mm256_cvtepi32_epi64(_mm256_extracti128_si256(cmp,
1)));
-
- return result_lo | (static_cast<uint64_t>(result_hi) << 32);
+ return Cmp32To8(cmp);
}
template <int column_width>
@@ -372,12 +328,7 @@ inline uint64_t Compare8_avx2(const uint8_t* left_base,
const uint8_t* right_bas
__m256i cmp = _mm256_cmpeq_epi32(left, right);
- uint32_t result_lo =
- _mm256_movemask_epi8(_mm256_cvtepi32_epi64(_mm256_castsi256_si128(cmp)));
- uint32_t result_hi =
- _mm256_movemask_epi8(_mm256_cvtepi32_epi64(_mm256_extracti128_si256(cmp,
1)));
-
- return result_lo | (static_cast<uint64_t>(result_hi) << 32);
+ return Cmp32To8(cmp);
}
template <bool use_selection>
@@ -402,9 +353,9 @@ inline uint64_t Compare8_64bit_avx2(const uint8_t*
left_base, const uint8_t* rig
reinterpret_cast<const arrow::util::int64_for_gather_t*>(right_base);
__m256i right_lo = _mm256_i64gather_epi64(right_base_i64, offset_right_lo,
1);
__m256i right_hi = _mm256_i64gather_epi64(right_base_i64, offset_right_hi,
1);
- uint32_t result_lo = _mm256_movemask_epi8(_mm256_cmpeq_epi64(left_lo,
right_lo));
- uint32_t result_hi = _mm256_movemask_epi8(_mm256_cmpeq_epi64(left_hi,
right_hi));
- return result_lo | (static_cast<uint64_t>(result_hi) << 32);
+ __m256i cmp_lo = _mm256_cmpeq_epi64(left_lo, right_lo);
+ __m256i cmp_hi = _mm256_cmpeq_epi64(left_hi, right_hi);
+ return Cmp64To8(cmp_lo, cmp_hi);
}
template <bool use_selection>
@@ -554,7 +505,7 @@ void KeyCompare::CompareVarBinaryColumnToRowImp_avx2(
const uint32_t* offsets_left = col.offsets();
const RowTableImpl::offset_type* offsets_right = rows.offsets();
const uint8_t* rows_left = col.data(2);
- const uint8_t* rows_right = rows.data(2);
+ const uint8_t* rows_right = rows.var_length_rows();
for (uint32_t i = 0; i < num_rows_to_compare; ++i) {
uint32_t irow_left = use_selection ? sel_left_maybe_null[i] : i;
uint32_t irow_right = left_to_right_map[irow_left];
diff --git a/cpp/src/arrow/compute/row/compare_test.cc
b/cpp/src/arrow/compute/row/compare_test.cc
index 5e8ee7c58a..2b8f4d9756 100644
--- a/cpp/src/arrow/compute/row/compare_test.cc
+++ b/cpp/src/arrow/compute/row/compare_test.cc
@@ -327,7 +327,7 @@ TEST(KeyCompare,
LARGE_MEMORY_TEST(CompareColumnsToRowsOver2GB)) {
ASSERT_OK_AND_ASSIGN(RowTableImpl row_table_right,
MakeRowTableFromExecBatch(batch_left));
// The row table must contain an offset buffer.
- ASSERT_NE(row_table_right.data(2), NULLPTR);
+ ASSERT_NE(row_table_right.var_length_rows(), NULLPTR);
// The whole point of this test.
ASSERT_GT(row_table_right.offsets()[num_rows - 1], k2GB);
@@ -387,7 +387,7 @@ TEST(KeyCompare,
LARGE_MEMORY_TEST(CompareColumnsToRowsOver4GBFixedLength)) {
RepeatRowTableUntil(MakeRowTableFromExecBatch(batch_left).ValueUnsafe(),
num_rows_row_table));
// The row table must not contain a third buffer.
- ASSERT_EQ(row_table_right.data(2), NULLPTR);
+ ASSERT_EQ(row_table_right.var_length_rows(), NULLPTR);
// The row data must be greater than 4GB.
ASSERT_GT(row_table_right.buffer_size(1), k4GB);
@@ -460,7 +460,7 @@ TEST(KeyCompare,
LARGE_MEMORY_TEST(CompareColumnsToRowsOver4GBVarLength)) {
RepeatRowTableUntil(MakeRowTableFromExecBatch(batch_left).ValueUnsafe(),
num_rows_row_table));
// The row table must contain an offset buffer.
- ASSERT_NE(row_table_right.data(2), NULLPTR);
+ ASSERT_NE(row_table_right.var_length_rows(), NULLPTR);
// At least the last row should be located at over 4GB.
ASSERT_GT(row_table_right.offsets()[num_rows_row_table - 1], k4GB);
diff --git a/cpp/src/arrow/compute/row/encode_internal.cc
b/cpp/src/arrow/compute/row/encode_internal.cc
index 127d43021d..0e2720a286 100644
--- a/cpp/src/arrow/compute/row/encode_internal.cc
+++ b/cpp/src/arrow/compute/row/encode_internal.cc
@@ -260,36 +260,32 @@ void EncoderInteger::Decode(uint32_t start_row, uint32_t
num_rows,
col_prep.metadata().fixed_length == rows.metadata().fixed_length) {
DCHECK_EQ(offset_within_row, 0);
uint32_t row_size = rows.metadata().fixed_length;
- memcpy(col_prep.mutable_data(1), rows.data(1) + start_row * row_size,
- num_rows * row_size);
+ memcpy(col_prep.mutable_data(1), rows.fixed_length_rows(start_row),
+ static_cast<int64_t>(num_rows) * row_size);
} else if (rows.metadata().is_fixed_length) {
- uint32_t row_size = rows.metadata().fixed_length;
- const uint8_t* row_base =
- rows.data(1) + static_cast<RowTableImpl::offset_type>(start_row) *
row_size;
- row_base += offset_within_row;
uint8_t* col_base = col_prep.mutable_data(1);
switch (col_prep.metadata().fixed_length) {
case 1:
for (uint32_t i = 0; i < num_rows; ++i) {
- col_base[i] = row_base[i * row_size];
+ col_base[i] = *(rows.fixed_length_rows(start_row + i) +
offset_within_row);
}
break;
case 2:
for (uint32_t i = 0; i < num_rows; ++i) {
- reinterpret_cast<uint16_t*>(col_base)[i] =
- *reinterpret_cast<const uint16_t*>(row_base + i * row_size);
+ reinterpret_cast<uint16_t*>(col_base)[i] = *reinterpret_cast<const
uint16_t*>(
+ rows.fixed_length_rows(start_row + i) + offset_within_row);
}
break;
case 4:
for (uint32_t i = 0; i < num_rows; ++i) {
- reinterpret_cast<uint32_t*>(col_base)[i] =
- *reinterpret_cast<const uint32_t*>(row_base + i * row_size);
+ reinterpret_cast<uint32_t*>(col_base)[i] = *reinterpret_cast<const
uint32_t*>(
+ rows.fixed_length_rows(start_row + i) + offset_within_row);
}
break;
case 8:
for (uint32_t i = 0; i < num_rows; ++i) {
- reinterpret_cast<uint64_t*>(col_base)[i] =
- *reinterpret_cast<const uint64_t*>(row_base + i * row_size);
+ reinterpret_cast<uint64_t*>(col_base)[i] = *reinterpret_cast<const
uint64_t*>(
+ rows.fixed_length_rows(start_row + i) + offset_within_row);
}
break;
default:
@@ -297,7 +293,7 @@ void EncoderInteger::Decode(uint32_t start_row, uint32_t
num_rows,
}
} else {
const RowTableImpl::offset_type* row_offsets = rows.offsets() + start_row;
- const uint8_t* row_base = rows.data(2);
+ const uint8_t* row_base = rows.var_length_rows();
row_base += offset_within_row;
uint8_t* col_base = col_prep.mutable_data(1);
switch (col_prep.metadata().fixed_length) {
@@ -343,14 +339,14 @@ void EncoderBinary::EncodeSelectedImp(uint32_t
offset_within_row, RowTableImpl*
if (is_fixed_length) {
uint32_t row_width = rows->metadata().fixed_length;
const uint8_t* src_base = col.data(1);
- uint8_t* dst = rows->mutable_data(1) + offset_within_row;
+ uint8_t* dst = rows->mutable_fixed_length_rows(/*row_id=*/0) +
offset_within_row;
for (uint32_t i = 0; i < num_selected; ++i) {
copy_fn(dst, src_base, selection[i]);
dst += row_width;
}
if (col.data(0)) {
const uint8_t* non_null_bits = col.data(0);
- uint8_t* dst = rows->mutable_data(1) + offset_within_row;
+ dst = rows->mutable_fixed_length_rows(/*row_id=*/0) + offset_within_row;
for (uint32_t i = 0; i < num_selected; ++i) {
bool is_null = !bit_util::GetBit(non_null_bits, selection[i] +
col.bit_offset(0));
if (is_null) {
@@ -361,14 +357,14 @@ void EncoderBinary::EncodeSelectedImp(uint32_t
offset_within_row, RowTableImpl*
}
} else {
const uint8_t* src_base = col.data(1);
- uint8_t* dst = rows->mutable_data(2) + offset_within_row;
+ uint8_t* dst = rows->mutable_var_length_rows() + offset_within_row;
const RowTableImpl::offset_type* offsets = rows->offsets();
for (uint32_t i = 0; i < num_selected; ++i) {
copy_fn(dst + offsets[i], src_base, selection[i]);
}
if (col.data(0)) {
const uint8_t* non_null_bits = col.data(0);
- uint8_t* dst = rows->mutable_data(2) + offset_within_row;
+ uint8_t* dst = rows->mutable_var_length_rows() + offset_within_row;
const RowTableImpl::offset_type* offsets = rows->offsets();
for (uint32_t i = 0; i < num_selected; ++i) {
bool is_null = !bit_util::GetBit(non_null_bits, selection[i] +
col.bit_offset(0));
@@ -584,16 +580,13 @@ void EncoderBinaryPair::DecodeImp(uint32_t
num_rows_to_skip, uint32_t start_row,
uint8_t* dst_A = col1->mutable_data(1);
uint8_t* dst_B = col2->mutable_data(1);
- uint32_t fixed_length = rows.metadata().fixed_length;
const RowTableImpl::offset_type* offsets;
const uint8_t* src_base;
if (is_row_fixed_length) {
- src_base = rows.data(1) +
- static_cast<RowTableImpl::offset_type>(start_row) *
fixed_length +
- offset_within_row;
+ src_base = rows.fixed_length_rows(start_row) + offset_within_row;
offsets = nullptr;
} else {
- src_base = rows.data(2) + offset_within_row;
+ src_base = rows.var_length_rows() + offset_within_row;
offsets = rows.offsets() + start_row;
}
@@ -601,6 +594,7 @@ void EncoderBinaryPair::DecodeImp(uint32_t
num_rows_to_skip, uint32_t start_row,
using col2_type_const = typename std::add_const<col2_type>::type;
if (is_row_fixed_length) {
+ uint32_t fixed_length = rows.metadata().fixed_length;
const uint8_t* src = src_base + num_rows_to_skip * fixed_length;
for (uint32_t i = num_rows_to_skip; i < num_rows; ++i) {
reinterpret_cast<col1_type*>(dst_A)[i] =
*reinterpret_cast<col1_type_const*>(src);
@@ -654,7 +648,7 @@ void EncoderOffsets::Decode(uint32_t start_row, uint32_t
num_rows,
for (uint32_t i = 0; i < num_rows; ++i) {
// Find the beginning of cumulative lengths array for next row
- const uint8_t* row = rows.data(2) + row_offsets[i];
+ const uint8_t* row = rows.var_length_rows() + row_offsets[i];
const uint32_t* varbinary_ends = rows.metadata().varbinary_end_array(row);
// Update the offset of each column
@@ -728,7 +722,7 @@ void EncoderOffsets::EncodeSelectedImp(uint32_t ivarbinary,
RowTableImpl* rows,
const std::vector<KeyColumnArray>& cols,
uint32_t num_selected, const uint16_t*
selection) {
const RowTableImpl::offset_type* row_offsets = rows->offsets();
- uint8_t* row_base = rows->mutable_data(2) +
+ uint8_t* row_base = rows->mutable_var_length_rows() +
rows->metadata().varbinary_end_array_offset +
ivarbinary * sizeof(uint32_t);
const uint32_t* col_offsets = cols[ivarbinary].offsets();
@@ -824,8 +818,6 @@ void EncoderNulls::Decode(uint32_t start_row, uint32_t
num_rows, const RowTableI
DCHECK(col.mutable_data(0) || col.metadata().is_null_type);
}
- const uint8_t* null_masks = rows.null_masks();
- uint32_t null_masks_bytes_per_row = rows.metadata().null_masks_bytes_per_row;
for (size_t col = 0; col < cols->size(); ++col) {
if ((*cols)[col].metadata().is_null_type) {
continue;
@@ -839,9 +831,7 @@ void EncoderNulls::Decode(uint32_t start_row, uint32_t
num_rows, const RowTableI
memset(non_nulls + 1, 0xff, bit_util::BytesForBits(num_rows -
bits_in_first_byte));
}
for (uint32_t row = 0; row < num_rows; ++row) {
- uint32_t null_masks_bit_id =
- (start_row + row) * null_masks_bytes_per_row * 8 +
static_cast<uint32_t>(col);
- bool is_set = bit_util::GetBit(null_masks, null_masks_bit_id);
+ bool is_set = rows.is_null(start_row + row, static_cast<uint32_t>(col));
if (is_set) {
bit_util::ClearBit(non_nulls, bit_offset + row);
}
@@ -853,7 +843,7 @@ void EncoderVarBinary::EncodeSelected(uint32_t ivarbinary,
RowTableImpl* rows,
const KeyColumnArray& cols, uint32_t
num_selected,
const uint16_t* selection) {
const RowTableImpl::offset_type* row_offsets = rows->offsets();
- uint8_t* row_base = rows->mutable_data(2);
+ uint8_t* row_base = rows->mutable_var_length_rows();
const uint32_t* col_offsets = cols.offsets();
const uint8_t* col_base = cols.data(2);
@@ -882,7 +872,7 @@ void EncoderVarBinary::EncodeSelected(uint32_t ivarbinary,
RowTableImpl* rows,
void EncoderNulls::EncodeSelected(RowTableImpl* rows,
const std::vector<KeyColumnArray>& cols,
uint32_t num_selected, const uint16_t*
selection) {
- uint8_t* null_masks = rows->null_masks();
+ uint8_t* null_masks = rows->mutable_null_masks(/*row_id=*/0);
uint32_t null_mask_num_bytes = rows->metadata().null_masks_bytes_per_row;
memset(null_masks, 0, null_mask_num_bytes * num_selected);
for (size_t icol = 0; icol < cols.size(); ++icol) {
diff --git a/cpp/src/arrow/compute/row/encode_internal.h
b/cpp/src/arrow/compute/row/encode_internal.h
index 37538fcc4b..5ad82e0c8e 100644
--- a/cpp/src/arrow/compute/row/encode_internal.h
+++ b/cpp/src/arrow/compute/row/encode_internal.h
@@ -164,11 +164,10 @@ class EncoderBinary {
uint32_t col_width = col_const->metadata().fixed_length;
if (is_row_fixed_length) {
- uint32_t row_width = rows_const->metadata().fixed_length;
for (uint32_t i = 0; i < num_rows; ++i) {
const uint8_t* src;
uint8_t* dst;
- src = rows_const->data(1) + row_width * (start_row + i) +
offset_within_row;
+ src = rows_const->fixed_length_rows(start_row + i) + offset_within_row;
dst = col_mutable_maybe_null->mutable_data(1) + col_width * i;
copy_fn(dst, src, col_width);
}
@@ -177,7 +176,8 @@ class EncoderBinary {
for (uint32_t i = 0; i < num_rows; ++i) {
const uint8_t* src;
uint8_t* dst;
- src = rows_const->data(2) + row_offsets[start_row + i] +
offset_within_row;
+ src = rows_const->var_length_rows() + row_offsets[start_row + i] +
+ offset_within_row;
dst = col_mutable_maybe_null->mutable_data(1) + col_width * i;
copy_fn(dst, src, col_width);
}
@@ -277,7 +277,7 @@ class EncoderVarBinary {
col_offset_next = col_offsets[i + 1];
RowTableImpl::offset_type row_offset = row_offsets_for_batch[i];
- const uint8_t* row = rows_const->data(2) + row_offset;
+ const uint8_t* row = rows_const->var_length_rows() + row_offset;
uint32_t offset_within_row;
uint32_t length;
@@ -293,7 +293,7 @@ class EncoderVarBinary {
const uint8_t* src;
uint8_t* dst;
- src = rows_const->data(2) + row_offset;
+ src = rows_const->var_length_rows() + row_offset;
dst = col_mutable_maybe_null->mutable_data(2) + col_offset;
copy_fn(dst, src, length);
}
diff --git a/cpp/src/arrow/compute/row/encode_internal_avx2.cc
b/cpp/src/arrow/compute/row/encode_internal_avx2.cc
index d2e317deb8..650d24b8ef 100644
--- a/cpp/src/arrow/compute/row/encode_internal_avx2.cc
+++ b/cpp/src/arrow/compute/row/encode_internal_avx2.cc
@@ -75,14 +75,9 @@ uint32_t EncoderBinaryPair::DecodeImp_avx2(uint32_t
start_row, uint32_t num_rows
uint32_t fixed_length = rows.metadata().fixed_length;
const RowTableImpl::offset_type* offsets;
- const uint8_t* src_base;
if (is_row_fixed_length) {
- src_base = rows.data(1) +
- static_cast<RowTableImpl::offset_type>(fixed_length) *
start_row +
- offset_within_row;
offsets = nullptr;
} else {
- src_base = rows.data(2) + offset_within_row;
offsets = rows.offsets() + start_row;
}
@@ -94,14 +89,15 @@ uint32_t EncoderBinaryPair::DecodeImp_avx2(uint32_t
start_row, uint32_t num_rows
for (uint32_t i = 0; i < num_rows / unroll; ++i) {
const __m128i *src0, *src1, *src2, *src3;
if (is_row_fixed_length) {
- const uint8_t* src = src_base + (i * unroll) * fixed_length;
+ const uint8_t* src =
+ rows.fixed_length_rows(start_row + i * unroll) + offset_within_row;
src0 = reinterpret_cast<const __m128i*>(src);
src1 = reinterpret_cast<const __m128i*>(src + fixed_length);
src2 = reinterpret_cast<const __m128i*>(src + fixed_length * 2);
src3 = reinterpret_cast<const __m128i*>(src + fixed_length * 3);
} else {
+ const uint8_t* src = rows.var_length_rows() + offset_within_row;
const RowTableImpl::offset_type* row_offsets = offsets + i * unroll;
- const uint8_t* src = src_base;
src0 = reinterpret_cast<const __m128i*>(src + row_offsets[0]);
src1 = reinterpret_cast<const __m128i*>(src + row_offsets[1]);
src2 = reinterpret_cast<const __m128i*>(src + row_offsets[2]);
@@ -127,7 +123,8 @@ uint32_t EncoderBinaryPair::DecodeImp_avx2(uint32_t
start_row, uint32_t num_rows
uint8_t buffer[64];
for (uint32_t i = 0; i < num_rows / unroll; ++i) {
if (is_row_fixed_length) {
- const uint8_t* src = src_base + (i * unroll) * fixed_length;
+ const uint8_t* src =
+ rows.fixed_length_rows(start_row + i * unroll) + offset_within_row;
for (int j = 0; j < unroll; ++j) {
if (col_width == 1) {
reinterpret_cast<uint16_t*>(buffer)[j] =
@@ -141,8 +138,8 @@ uint32_t EncoderBinaryPair::DecodeImp_avx2(uint32_t
start_row, uint32_t num_rows
}
}
} else {
+ const uint8_t* src = rows.var_length_rows() + offset_within_row;
const RowTableImpl::offset_type* row_offsets = offsets + i * unroll;
- const uint8_t* src = src_base;
for (int j = 0; j < unroll; ++j) {
if (col_width == 1) {
reinterpret_cast<uint16_t*>(buffer)[j] =
diff --git a/cpp/src/arrow/compute/row/row_internal.cc
b/cpp/src/arrow/compute/row/row_internal.cc
index aa7e62add4..492cc71ac4 100644
--- a/cpp/src/arrow/compute/row/row_internal.cc
+++ b/cpp/src/arrow/compute/row/row_internal.cc
@@ -406,10 +406,14 @@ bool RowTableImpl::has_any_nulls(const LightContext* ctx)
const {
return true;
}
if (num_rows_for_has_any_nulls_ < num_rows_) {
- auto size_per_row = metadata().null_masks_bytes_per_row;
+ DCHECK_LE(num_rows_for_has_any_nulls_,
std::numeric_limits<uint32_t>::max());
+ int64_t num_bytes =
+ metadata().null_masks_bytes_per_row * (num_rows_ -
num_rows_for_has_any_nulls_);
+ DCHECK_LE(num_bytes, std::numeric_limits<uint32_t>::max());
has_any_nulls_ = !util::bit_util::are_all_bytes_zero(
- ctx->hardware_flags, null_masks() + size_per_row *
num_rows_for_has_any_nulls_,
- static_cast<uint32_t>(size_per_row * (num_rows_ -
num_rows_for_has_any_nulls_)));
+ ctx->hardware_flags,
+ null_masks(static_cast<uint32_t>(num_rows_for_has_any_nulls_)),
+ static_cast<uint32_t>(num_bytes));
num_rows_for_has_any_nulls_ = num_rows_;
}
return has_any_nulls_;
diff --git a/cpp/src/arrow/compute/row/row_internal.h
b/cpp/src/arrow/compute/row/row_internal.h
index 3ab86fd1fc..0919773a22 100644
--- a/cpp/src/arrow/compute/row/row_internal.h
+++ b/cpp/src/arrow/compute/row/row_internal.h
@@ -199,29 +199,44 @@ class ARROW_EXPORT RowTableImpl {
const RowTableMetadata& metadata() const { return metadata_; }
/// \brief The number of rows stored in the table
int64_t length() const { return num_rows_; }
- // Accessors into the table's buffers
- const uint8_t* data(int i) const {
- ARROW_DCHECK(i >= 0 && i < kMaxBuffers);
- if (ARROW_PREDICT_TRUE(buffers_[i])) {
- return buffers_[i]->data();
- }
- return NULLPTR;
+
+ const uint8_t* null_masks(uint32_t row_id) const {
+ return data(0) + static_cast<int64_t>(row_id) *
metadata_.null_masks_bytes_per_row;
}
- uint8_t* mutable_data(int i) {
- ARROW_DCHECK(i >= 0 && i < kMaxBuffers);
- if (ARROW_PREDICT_TRUE(buffers_[i])) {
- return buffers_[i]->mutable_data();
- }
- return NULLPTR;
+ uint8_t* mutable_null_masks(uint32_t row_id) {
+ return mutable_data(0) +
+ static_cast<int64_t>(row_id) * metadata_.null_masks_bytes_per_row;
+ }
+ bool is_null(uint32_t row_id, uint32_t col_pos) const {
+ return bit_util::GetBit(null_masks(row_id), col_pos);
}
+
+ const uint8_t* fixed_length_rows(uint32_t row_id) const {
+ ARROW_DCHECK(metadata_.is_fixed_length);
+ return data(1) + static_cast<int64_t>(row_id) * metadata_.fixed_length;
+ }
+ uint8_t* mutable_fixed_length_rows(uint32_t row_id) {
+ ARROW_DCHECK(metadata_.is_fixed_length);
+ return mutable_data(1) + static_cast<int64_t>(row_id) *
metadata_.fixed_length;
+ }
+
const offset_type* offsets() const {
+ ARROW_DCHECK(!metadata_.is_fixed_length);
return reinterpret_cast<const offset_type*>(data(1));
}
offset_type* mutable_offsets() {
+ ARROW_DCHECK(!metadata_.is_fixed_length);
return reinterpret_cast<offset_type*>(mutable_data(1));
}
- const uint8_t* null_masks() const { return null_masks_->data(); }
- uint8_t* null_masks() { return null_masks_->mutable_data(); }
+
+ const uint8_t* var_length_rows() const {
+ ARROW_DCHECK(!metadata_.is_fixed_length);
+ return data(2);
+ }
+ uint8_t* mutable_var_length_rows() {
+ ARROW_DCHECK(!metadata_.is_fixed_length);
+ return mutable_data(2);
+ }
/// \brief True if there is a null value anywhere in the table
///
@@ -237,6 +252,22 @@ class ARROW_EXPORT RowTableImpl {
}
private:
+ // Accessors into the table's buffers
+ const uint8_t* data(int i) const {
+ ARROW_DCHECK(i >= 0 && i < kMaxBuffers);
+ if (ARROW_PREDICT_TRUE(buffers_[i])) {
+ return buffers_[i]->data();
+ }
+ return NULLPTR;
+ }
+ uint8_t* mutable_data(int i) {
+ ARROW_DCHECK(i >= 0 && i < kMaxBuffers);
+ if (ARROW_PREDICT_TRUE(buffers_[i])) {
+ return buffers_[i]->mutable_data();
+ }
+ return NULLPTR;
+ }
+
/// \brief Resize the fixed length buffers to store `num_extra_rows` more
rows. The
/// fixed length buffers are buffers_[0] for null masks, buffers_[1] for row
data if the
/// row is fixed length, or for row offsets otherwise.
diff --git a/cpp/src/arrow/compute/row/row_test.cc
b/cpp/src/arrow/compute/row/row_test.cc
index 5057ce91b5..49d8f2a9af 100644
--- a/cpp/src/arrow/compute/row/row_test.cc
+++ b/cpp/src/arrow/compute/row/row_test.cc
@@ -92,9 +92,8 @@ TEST(RowTableMemoryConsumption, Encode) {
ASSERT_OK_AND_ASSIGN(auto row_table,
MakeRowTableFromColumn(col, num_rows,
dt->byte_width(),
/*string_alignment=*/0));
- ASSERT_NE(row_table.data(0), NULLPTR);
- ASSERT_NE(row_table.data(1), NULLPTR);
- ASSERT_EQ(row_table.data(2), NULLPTR);
+ ASSERT_NE(row_table.null_masks(/*row_id=*/0), NULLPTR);
+ ASSERT_NE(row_table.fixed_length_rows(/*row_id=*/0), NULLPTR);
int64_t actual_null_mask_size =
num_rows * row_table.metadata().null_masks_bytes_per_row;
@@ -113,9 +112,9 @@ TEST(RowTableMemoryConsumption, Encode) {
SCOPED_TRACE("encoding var length column of " + std::to_string(num_rows)
+ " rows");
ASSERT_OK_AND_ASSIGN(auto row_table,
MakeRowTableFromColumn(var_length_column, num_rows,
4, 4));
- ASSERT_NE(row_table.data(0), NULLPTR);
- ASSERT_NE(row_table.data(1), NULLPTR);
- ASSERT_NE(row_table.data(2), NULLPTR);
+ ASSERT_NE(row_table.null_masks(/*row_id=*/0), NULLPTR);
+ ASSERT_NE(row_table.offsets(), NULLPTR);
+ ASSERT_NE(row_table.var_length_rows(), NULLPTR);
int64_t actual_null_mask_size =
num_rows * row_table.metadata().null_masks_bytes_per_row;
diff --git a/cpp/src/arrow/compute/row/row_util_avx2_internal.h
b/cpp/src/arrow/compute/row/row_util_avx2_internal.h
new file mode 100644
index 0000000000..a8fce7e0e8
--- /dev/null
+++ b/cpp/src/arrow/compute/row/row_util_avx2_internal.h
@@ -0,0 +1,64 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
+#pragma once
+
+#include "arrow/compute/row/row_internal.h"
+#include "arrow/util/simd.h"
+
+#if !defined(ARROW_HAVE_AVX2) && !defined(ARROW_HAVE_AVX512) && \
+ !defined(ARROW_HAVE_RUNTIME_AVX2) && !defined(ARROW_HAVE_RUNTIME_AVX512)
+# error "This file should only be included when AVX2 or AVX512 is enabled"
+#endif
+
+namespace arrow::compute {
+
+// Convert 8 64-bit comparison results, each 0 or -1, to 8 bytes of 0x00 or 0xFF.
+inline uint64_t Cmp64To8(__m256i cmp64_lo, __m256i cmp64_hi) {
+ uint32_t cmp_lo = _mm256_movemask_epi8(cmp64_lo);
+ uint32_t cmp_hi = _mm256_movemask_epi8(cmp64_hi);
+ return cmp_lo | (static_cast<uint64_t>(cmp_hi) << 32);
+}
+
+// Convert 8 32-bit comparison results, each 0 or -1, to 8 bytes of 0x00 or 0xFF.
+inline uint64_t Cmp32To8(__m256i cmp32) {
+ return Cmp64To8(_mm256_cvtepi32_epi64(_mm256_castsi256_si128(cmp32)),
+ _mm256_cvtepi32_epi64(_mm256_extracti128_si256(cmp32, 1)));
+}
+
+// Get null bits for 8 32-bit row ids in `row_id32` at `col_pos` as a vector of 32-bit
+// integers: 0 if the column is not null in that row, 1 otherwise. Row ids and the
+// null mask bit width are used as signed int32 (epi32 ops), so keep them <= INT32_MAX.
+inline __m256i GetNullBitInt32(const RowTableImpl& rows, uint32_t col_pos,
+ __m256i row_id32) {
+ const uint8_t* null_masks = rows.null_masks(/*row_id=*/0);
+ __m256i null_mask_num_bits =
+ _mm256_set1_epi64x(rows.metadata().null_masks_bytes_per_row * 8);
+ __m256i row_lo = _mm256_cvtepi32_epi64(_mm256_castsi256_si128(row_id32));
+ __m256i row_hi = _mm256_cvtepi32_epi64(_mm256_extracti128_si256(row_id32,
1));
+ __m256i bit_id_lo = _mm256_mul_epi32(row_lo, null_mask_num_bits);
+ __m256i bit_id_hi = _mm256_mul_epi32(row_hi, null_mask_num_bits);
+ bit_id_lo = _mm256_add_epi64(bit_id_lo, _mm256_set1_epi64x(col_pos));
+ bit_id_hi = _mm256_add_epi64(bit_id_hi, _mm256_set1_epi64x(col_pos));
+ __m128i right_lo = _mm256_i64gather_epi32(reinterpret_cast<const
int*>(null_masks),
+ _mm256_srli_epi64(bit_id_lo, 3),
1);
+ __m128i right_hi = _mm256_i64gather_epi32(reinterpret_cast<const
int*>(null_masks),
+ _mm256_srli_epi64(bit_id_hi, 3),
1);
+ __m256i right = _mm256_set_m128i(right_hi, right_lo);
+ return _mm256_and_si256(_mm256_set1_epi32(1), _mm256_srli_epi32(right,
col_pos & 7));
+}
+
+} // namespace arrow::compute