This is an automated email from the ASF dual-hosted git repository.
apitrou pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/arrow.git
The following commit(s) were added to refs/heads/main by this push:
new 3d6d581731 GH-44052: [C++][Compute] Reduce the complexity of row
segmenter (#44053)
3d6d581731 is described below
commit 3d6d5817313920abc71c854828d95b63b2562938
Author: Rossi Sun <[email protected]>
AuthorDate: Wed Sep 18 15:51:03 2024 +0800
GH-44052: [C++][Compute] Reduce the complexity of row segmenter (#44053)
### Rationale for this change
As described in #44052, `AnyKeysSegmenter::GetNextSegment` currently has
`O(n*m)` complexity, where `n` is the number of rows in a batch and `m` is the
number of segments in that batch (a "segment" is a group of contiguous rows
that share the same segment key). This is because each invocation of the
method computes the group ids of all the remaining rows in the batch,
even though it is only interested in the first group, so the rest of the
computation is wasted.
In this PR I introduce a new API, `GetSegments` (and consequently
deprecate the old `GetNextSegment`), which computes the group ids only once and
returns all the segments of a batch, so that callers can iterate over them
without repeating the computation. This
reduces the complexity from `O(n*m)` to `O(n)`.
### What changes are included in this PR?
1. Because `grouper.h` is a [public
header](https://github.com/apache/arrow/blob/8556001e6a8b4c7f35d4e18c28704d7811005904/cpp/src/arrow/compute/api.h#L47),
I assume `RowSegmenter::GetNextSegment` is a public API, so it is only
deprecated rather than removed.
2. Implement new API `RowSegmenter::GetSegments` and update the call-sites.
3. Some code reorg of the segmenter code (mostly moving to inside a class).
4. A new benchmark for segmented aggregation. (The benchmark results are
listed in the comments below; they show up to a `50x` speedup, consistent with
the complexity reduction from `O(n*m)` to `O(n)`.)
### Are these changes tested?
Yes. The existing tests, updated to call `GetSegments`, are sufficient.
### Are there any user-facing changes?
Yes.
**This PR includes breaking changes to public APIs.**
The API `RowSegmenter::GetNextSegment` is deprecated due to its
inefficiency and is replaced by the more efficient `RowSegmenter::GetSegments`.
* GitHub Issue: #44052
Lead-authored-by: Ruoxi Sun <[email protected]>
Co-authored-by: Rossi Sun <[email protected]>
Co-authored-by: Antoine Pitrou <[email protected]>
Signed-off-by: Antoine Pitrou <[email protected]>
---
cpp/src/arrow/acero/aggregate_benchmark.cc | 119 ++++++++++---
cpp/src/arrow/acero/aggregate_internal.h | 9 +-
cpp/src/arrow/acero/hash_aggregate_test.cc | 66 +++-----
cpp/src/arrow/compute/row/grouper.cc | 263 ++++++++++++++++++++++-------
cpp/src/arrow/compute/row/grouper.h | 5 +
5 files changed, 332 insertions(+), 130 deletions(-)
diff --git a/cpp/src/arrow/acero/aggregate_benchmark.cc
b/cpp/src/arrow/acero/aggregate_benchmark.cc
index c0dfba6633..9c90b63904 100644
--- a/cpp/src/arrow/acero/aggregate_benchmark.cc
+++ b/cpp/src/arrow/acero/aggregate_benchmark.cc
@@ -24,6 +24,7 @@
#include "arrow/array/array_primitive.h"
#include "arrow/compute/api.h"
#include "arrow/table.h"
+#include "arrow/testing/generator.h"
#include "arrow/testing/gtest_util.h"
#include "arrow/testing/random.h"
#include "arrow/util/benchmark_util.h"
@@ -325,7 +326,8 @@ BENCHMARK_TEMPLATE(ReferenceSum,
SumBitmapVectorizeUnroll<int64_t>)
std::shared_ptr<RecordBatch> RecordBatchFromArrays(
const std::vector<std::shared_ptr<Array>>& arguments,
- const std::vector<std::shared_ptr<Array>>& keys) {
+ const std::vector<std::shared_ptr<Array>>& keys,
+ const std::vector<std::shared_ptr<Array>>& segment_keys) {
std::vector<std::shared_ptr<Field>> fields;
std::vector<std::shared_ptr<Array>> all_arrays;
int64_t length = -1;
@@ -347,37 +349,56 @@ std::shared_ptr<RecordBatch> RecordBatchFromArrays(
fields.push_back(field("key" + ToChars(key_idx), key->type()));
all_arrays.push_back(key);
}
+ for (std::size_t segment_key_idx = 0; segment_key_idx < segment_keys.size();
+ segment_key_idx++) {
+ const auto& segment_key = segment_keys[segment_key_idx];
+ DCHECK_EQ(segment_key->length(), length);
+ fields.push_back(
+ field("segment_key" + ToChars(segment_key_idx), segment_key->type()));
+ all_arrays.push_back(segment_key);
+ }
return RecordBatch::Make(schema(std::move(fields)), length,
std::move(all_arrays));
}
Result<std::shared_ptr<Table>> BatchGroupBy(
std::shared_ptr<RecordBatch> batch, std::vector<Aggregate> aggregates,
- std::vector<FieldRef> keys, bool use_threads = false,
- MemoryPool* memory_pool = default_memory_pool()) {
+ std::vector<FieldRef> keys, std::vector<FieldRef> segment_keys,
+ bool use_threads = false, MemoryPool* memory_pool = default_memory_pool())
{
ARROW_ASSIGN_OR_RAISE(std::shared_ptr<Table> table,
Table::FromRecordBatches({std::move(batch)}));
Declaration plan = Declaration::Sequence(
{{"table_source", TableSourceNodeOptions(std::move(table))},
- {"aggregate", AggregateNodeOptions(std::move(aggregates),
std::move(keys))}});
+ {"aggregate", AggregateNodeOptions(std::move(aggregates),
std::move(keys),
+ std::move(segment_keys))}});
return DeclarationToTable(std::move(plan), use_threads, memory_pool);
}
-static void BenchmarkGroupBy(benchmark::State& state, std::vector<Aggregate>
aggregates,
- const std::vector<std::shared_ptr<Array>>&
arguments,
- const std::vector<std::shared_ptr<Array>>& keys) {
- std::shared_ptr<RecordBatch> batch = RecordBatchFromArrays(arguments, keys);
+static void BenchmarkAggregate(
+ benchmark::State& state, std::vector<Aggregate> aggregates,
+ const std::vector<std::shared_ptr<Array>>& arguments,
+ const std::vector<std::shared_ptr<Array>>& keys,
+ const std::vector<std::shared_ptr<Array>>& segment_keys = {}) {
+ std::shared_ptr<RecordBatch> batch =
+ RecordBatchFromArrays(arguments, keys, segment_keys);
std::vector<FieldRef> key_refs;
for (std::size_t key_idx = 0; key_idx < keys.size(); key_idx++) {
key_refs.emplace_back(static_cast<int>(key_idx + arguments.size()));
}
+ std::vector<FieldRef> segment_key_refs;
+ for (std::size_t segment_key_idx = 0; segment_key_idx < segment_keys.size();
+ segment_key_idx++) {
+ segment_key_refs.emplace_back(
+ static_cast<int>(segment_key_idx + arguments.size() + keys.size()));
+ }
for (std::size_t arg_idx = 0; arg_idx < arguments.size(); arg_idx++) {
aggregates[arg_idx].target = {FieldRef(static_cast<int>(arg_idx))};
}
int64_t total_bytes = TotalBufferSize(*batch);
for (auto _ : state) {
- ABORT_NOT_OK(BatchGroupBy(batch, aggregates, key_refs));
+ ABORT_NOT_OK(BatchGroupBy(batch, aggregates, key_refs, segment_key_refs));
}
state.SetBytesProcessed(total_bytes * state.iterations());
+ state.SetItemsProcessed(batch->num_rows() * state.iterations());
}
#define GROUP_BY_BENCHMARK(Name, Impl) \
@@ -404,7 +425,7 @@ GROUP_BY_BENCHMARK(SumDoublesGroupedByTinyStringSet, [&] {
/*min_length=*/3,
/*max_length=*/32);
- BenchmarkGroupBy(state, {{"hash_sum", ""}}, {summand}, {key});
+ BenchmarkAggregate(state, {{"hash_sum", ""}}, {summand}, {key});
});
GROUP_BY_BENCHMARK(SumDoublesGroupedBySmallStringSet, [&] {
@@ -419,7 +440,7 @@ GROUP_BY_BENCHMARK(SumDoublesGroupedBySmallStringSet, [&] {
/*min_length=*/3,
/*max_length=*/32);
- BenchmarkGroupBy(state, {{"hash_sum", ""}}, {summand}, {key});
+ BenchmarkAggregate(state, {{"hash_sum", ""}}, {summand}, {key});
});
GROUP_BY_BENCHMARK(SumDoublesGroupedByMediumStringSet, [&] {
@@ -434,7 +455,7 @@ GROUP_BY_BENCHMARK(SumDoublesGroupedByMediumStringSet, [&] {
/*min_length=*/3,
/*max_length=*/32);
- BenchmarkGroupBy(state, {{"hash_sum", ""}}, {summand}, {key});
+ BenchmarkAggregate(state, {{"hash_sum", ""}}, {summand}, {key});
});
GROUP_BY_BENCHMARK(SumDoublesGroupedByTinyIntegerSet, [&] {
@@ -448,7 +469,7 @@ GROUP_BY_BENCHMARK(SumDoublesGroupedByTinyIntegerSet, [&] {
/*min=*/0,
/*max=*/15);
- BenchmarkGroupBy(state, {{"hash_sum", ""}}, {summand}, {key});
+ BenchmarkAggregate(state, {{"hash_sum", ""}}, {summand}, {key});
});
GROUP_BY_BENCHMARK(SumDoublesGroupedBySmallIntegerSet, [&] {
@@ -462,7 +483,7 @@ GROUP_BY_BENCHMARK(SumDoublesGroupedBySmallIntegerSet, [&] {
/*min=*/0,
/*max=*/255);
- BenchmarkGroupBy(state, {{"hash_sum", ""}}, {summand}, {key});
+ BenchmarkAggregate(state, {{"hash_sum", ""}}, {summand}, {key});
});
GROUP_BY_BENCHMARK(SumDoublesGroupedByMediumIntegerSet, [&] {
@@ -476,7 +497,7 @@ GROUP_BY_BENCHMARK(SumDoublesGroupedByMediumIntegerSet, [&]
{
/*min=*/0,
/*max=*/4095);
- BenchmarkGroupBy(state, {{"hash_sum", ""}}, {summand}, {key});
+ BenchmarkAggregate(state, {{"hash_sum", ""}}, {summand}, {key});
});
GROUP_BY_BENCHMARK(SumDoublesGroupedByTinyIntStringPairSet, [&] {
@@ -494,7 +515,7 @@ GROUP_BY_BENCHMARK(SumDoublesGroupedByTinyIntStringPairSet,
[&] {
/*min_length=*/3,
/*max_length=*/32);
- BenchmarkGroupBy(state, {{"hash_sum", ""}}, {summand}, {int_key, str_key});
+ BenchmarkAggregate(state, {{"hash_sum", ""}}, {summand}, {int_key, str_key});
});
GROUP_BY_BENCHMARK(SumDoublesGroupedBySmallIntStringPairSet, [&] {
@@ -512,7 +533,7 @@
GROUP_BY_BENCHMARK(SumDoublesGroupedBySmallIntStringPairSet, [&] {
/*min_length=*/3,
/*max_length=*/32);
- BenchmarkGroupBy(state, {{"hash_sum", ""}}, {summand}, {int_key, str_key});
+ BenchmarkAggregate(state, {{"hash_sum", ""}}, {summand}, {int_key, str_key});
});
GROUP_BY_BENCHMARK(SumDoublesGroupedByMediumIntStringPairSet, [&] {
@@ -530,7 +551,7 @@
GROUP_BY_BENCHMARK(SumDoublesGroupedByMediumIntStringPairSet, [&] {
/*min_length=*/3,
/*max_length=*/32);
- BenchmarkGroupBy(state, {{"hash_sum", ""}}, {summand}, {int_key, str_key});
+ BenchmarkAggregate(state, {{"hash_sum", ""}}, {summand}, {int_key, str_key});
});
// Grouped MinMax
@@ -543,7 +564,7 @@ GROUP_BY_BENCHMARK(MinMaxDoublesGroupedByMediumInt, [&] {
/*nan_probability=*/args.null_proportion / 10);
auto int_key = rng.Int64(args.size, /*min=*/0, /*max=*/63);
- BenchmarkGroupBy(state, {{"hash_min_max", ""}}, {input}, {int_key});
+ BenchmarkAggregate(state, {{"hash_min_max", ""}}, {input}, {int_key});
});
GROUP_BY_BENCHMARK(MinMaxShortStringsGroupedByMediumInt, [&] {
@@ -553,7 +574,7 @@ GROUP_BY_BENCHMARK(MinMaxShortStringsGroupedByMediumInt,
[&] {
/*null_probability=*/args.null_proportion);
auto int_key = rng.Int64(args.size, /*min=*/0, /*max=*/63);
- BenchmarkGroupBy(state, {{"hash_min_max", ""}}, {input}, {int_key});
+ BenchmarkAggregate(state, {{"hash_min_max", ""}}, {input}, {int_key});
});
GROUP_BY_BENCHMARK(MinMaxLongStringsGroupedByMediumInt, [&] {
@@ -563,7 +584,7 @@ GROUP_BY_BENCHMARK(MinMaxLongStringsGroupedByMediumInt, [&]
{
/*null_probability=*/args.null_proportion);
auto int_key = rng.Int64(args.size, /*min=*/0, /*max=*/63);
- BenchmarkGroupBy(state, {{"hash_min_max", ""}}, {input}, {int_key});
+ BenchmarkAggregate(state, {{"hash_min_max", ""}}, {input}, {int_key});
});
//
@@ -866,5 +887,61 @@
BENCHMARK(TDigestKernelDoubleMedian)->Apply(QuantileKernelArgs);
BENCHMARK(TDigestKernelDoubleDeciles)->Apply(QuantileKernelArgs);
BENCHMARK(TDigestKernelDoubleCentiles)->Apply(QuantileKernelArgs);
+//
+// Segmented Aggregate
+//
+
+static void BenchmarkSegmentedAggregate(
+ benchmark::State& state, int64_t num_rows, std::vector<Aggregate>
aggregates,
+ const std::vector<std::shared_ptr<Array>>& arguments,
+ const std::vector<std::shared_ptr<Array>>& keys, int64_t num_segment_keys,
+ int64_t num_segments) {
+ ASSERT_GT(num_segments, 0);
+
+ auto rng = random::RandomArrayGenerator(42);
+ auto segment_key = rng.Int64(num_rows, /*min=*/0, /*max=*/num_segments - 1);
+ int64_t* values = segment_key->data()->GetMutableValues<int64_t>(1);
+ std::sort(values, values + num_rows);
+ // num_segment_keys copies of the segment key.
+ ArrayVector segment_keys(num_segment_keys, segment_key);
+
+ BenchmarkAggregate(state, std::move(aggregates), arguments, keys,
segment_keys);
+}
+
+template <typename... Args>
+static void CountScalarSegmentedByInts(benchmark::State& state, Args&&...) {
+ constexpr int64_t num_rows = 32 * 1024;
+
+ // A trivial column to count from.
+ auto arg = ConstantArrayGenerator::Zeroes(num_rows, int32());
+
+ BenchmarkSegmentedAggregate(state, num_rows, {{"count", ""}}, {arg},
/*keys=*/{},
+ state.range(0), state.range(1));
+}
+BENCHMARK(CountScalarSegmentedByInts)
+ ->ArgNames({"SegmentKeys", "Segments"})
+ ->ArgsProduct({{0, 1, 2}, benchmark::CreateRange(1, 256, 8)});
+
+template <typename... Args>
+static void CountGroupByIntsSegmentedByInts(benchmark::State& state,
Args&&...) {
+ constexpr int64_t num_rows = 32 * 1024;
+
+ // A trivial column to count from.
+ auto arg = ConstantArrayGenerator::Zeroes(num_rows, int32());
+
+ auto rng = random::RandomArrayGenerator(42);
+ int64_t num_keys = state.range(0);
+ ArrayVector keys(num_keys);
+ for (auto& key : keys) {
+ key = rng.Int64(num_rows, /*min=*/0, /*max=*/64);
+ }
+
+ BenchmarkSegmentedAggregate(state, num_rows, {{"hash_count", ""}}, {arg},
keys,
+ state.range(1), state.range(2));
+}
+BENCHMARK(CountGroupByIntsSegmentedByInts)
+ ->ArgNames({"Keys", "SegmentKeys", "Segments"})
+ ->ArgsProduct({{1, 2}, {0, 1, 2}, benchmark::CreateRange(1, 256, 8)});
+
} // namespace acero
} // namespace arrow
diff --git a/cpp/src/arrow/acero/aggregate_internal.h
b/cpp/src/arrow/acero/aggregate_internal.h
index 5730d99f93..7cdc424cbb 100644
--- a/cpp/src/arrow/acero/aggregate_internal.h
+++ b/cpp/src/arrow/acero/aggregate_internal.h
@@ -131,17 +131,14 @@ void AggregatesToString(std::stringstream* ss, const
Schema& input_schema,
template <typename BatchHandler>
Status HandleSegments(RowSegmenter* segmenter, const ExecBatch& batch,
const std::vector<int>& ids, const BatchHandler&
handle_batch) {
- int64_t offset = 0;
ARROW_ASSIGN_OR_RAISE(auto segment_exec_batch, batch.SelectValues(ids));
ExecSpan segment_batch(segment_exec_batch);
- while (true) {
- ARROW_ASSIGN_OR_RAISE(compute::Segment segment,
- segmenter->GetNextSegment(segment_batch, offset));
- if (segment.offset >= segment_batch.length) break; // condition of
no-next-segment
+ ARROW_ASSIGN_OR_RAISE(auto segments, segmenter->GetSegments(segment_batch));
+ for (const auto& segment : segments) {
ARROW_RETURN_NOT_OK(handle_batch(batch, segment));
- offset = segment.offset + segment.length;
}
+
return Status::OK();
}
diff --git a/cpp/src/arrow/acero/hash_aggregate_test.cc
b/cpp/src/arrow/acero/hash_aggregate_test.cc
index 743cb20d19..f76e326cd7 100644
--- a/cpp/src/arrow/acero/hash_aggregate_test.cc
+++ b/cpp/src/arrow/acero/hash_aggregate_test.cc
@@ -585,19 +585,12 @@ void TestGroupClassSupportedKeys(
void TestSegments(std::unique_ptr<RowSegmenter>& segmenter, const ExecSpan&
batch,
std::vector<Segment> expected_segments) {
- int64_t offset = 0, segment_num = 0;
- for (auto expected_segment : expected_segments) {
- SCOPED_TRACE("segment #" + ToChars(segment_num++));
- ASSERT_OK_AND_ASSIGN(auto segment, segmenter->GetNextSegment(batch,
offset));
- ASSERT_EQ(expected_segment, segment);
- offset = segment.offset + segment.length;
+ ASSERT_OK_AND_ASSIGN(auto actual_segments, segmenter->GetSegments(batch));
+ ASSERT_EQ(actual_segments.size(), expected_segments.size());
+ for (size_t i = 0; i < actual_segments.size(); ++i) {
+ SCOPED_TRACE("segment #" + ToChars(i));
+ ASSERT_EQ(actual_segments[i], expected_segments[i]);
}
- // Assert next is the last (empty) segment.
- ASSERT_OK_AND_ASSIGN(auto segment, segmenter->GetNextSegment(batch, offset));
- ASSERT_GE(segment.offset, batch.length);
- ASSERT_EQ(segment.length, 0);
- ASSERT_TRUE(segment.is_open);
- ASSERT_TRUE(segment.extends);
}
Result<std::unique_ptr<Grouper>> MakeGrouper(const std::vector<TypeHolder>&
key_types) {
@@ -629,61 +622,47 @@ TEST(RowSegmenter, Basics) {
auto batch2 = ExecBatchFromJSON(types2, "[[1, 1], [1, 2], [2, 2]]");
auto batch1 = ExecBatchFromJSON(types1, "[[1], [1], [2]]");
ExecBatch batch0({}, 3);
- {
- SCOPED_TRACE("offset");
- ASSERT_OK_AND_ASSIGN(auto segmenter, MakeRowSegmenter(types0));
- ExecSpan span0(batch0);
- for (int64_t offset : {-1, 4}) {
- EXPECT_RAISES_WITH_MESSAGE_THAT(Invalid,
- HasSubstr("invalid grouping segmenter
offset"),
- segmenter->GetNextSegment(span0,
offset));
- }
- }
{
SCOPED_TRACE("types0 segmenting of batch2");
ASSERT_OK_AND_ASSIGN(auto segmenter, MakeRowSegmenter(types0));
ExecSpan span2(batch2);
EXPECT_RAISES_WITH_MESSAGE_THAT(Invalid, HasSubstr("expected batch size 0
"),
- segmenter->GetNextSegment(span2, 0));
+ segmenter->GetSegments(span2));
ExecSpan span0(batch0);
- TestSegments(segmenter, span0, {{0, 3, true, true}, {3, 0, true, true}});
+ TestSegments(segmenter, span0, {{0, 3, true, true}});
}
{
SCOPED_TRACE("bad_types1 segmenting of batch1");
ASSERT_OK_AND_ASSIGN(auto segmenter, MakeRowSegmenter(bad_types1));
ExecSpan span1(batch1);
EXPECT_RAISES_WITH_MESSAGE_THAT(Invalid, HasSubstr("expected batch value 0
of type "),
- segmenter->GetNextSegment(span1, 0));
+ segmenter->GetSegments(span1));
}
{
SCOPED_TRACE("types1 segmenting of batch2");
ASSERT_OK_AND_ASSIGN(auto segmenter, MakeRowSegmenter(types1));
ExecSpan span2(batch2);
EXPECT_RAISES_WITH_MESSAGE_THAT(Invalid, HasSubstr("expected batch size 1
"),
- segmenter->GetNextSegment(span2, 0));
+ segmenter->GetSegments(span2));
ExecSpan span1(batch1);
- TestSegments(segmenter, span1,
- {{0, 2, false, true}, {2, 1, true, false}, {3, 0, true,
true}});
+ TestSegments(segmenter, span1, {{0, 2, false, true}, {2, 1, true, false}});
}
{
SCOPED_TRACE("bad_types2 segmenting of batch2");
ASSERT_OK_AND_ASSIGN(auto segmenter, MakeRowSegmenter(bad_types2));
ExecSpan span2(batch2);
EXPECT_RAISES_WITH_MESSAGE_THAT(Invalid, HasSubstr("expected batch value 1
of type "),
- segmenter->GetNextSegment(span2, 0));
+ segmenter->GetSegments(span2));
}
{
SCOPED_TRACE("types2 segmenting of batch1");
ASSERT_OK_AND_ASSIGN(auto segmenter, MakeRowSegmenter(types2));
ExecSpan span1(batch1);
EXPECT_RAISES_WITH_MESSAGE_THAT(Invalid, HasSubstr("expected batch size 2
"),
- segmenter->GetNextSegment(span1, 0));
+ segmenter->GetSegments(span1));
ExecSpan span2(batch2);
TestSegments(segmenter, span2,
- {{0, 1, false, true},
- {1, 1, false, false},
- {2, 1, true, false},
- {3, 0, true, true}});
+ {{0, 1, false, true}, {1, 1, false, false}, {2, 1, true,
false}});
}
}
@@ -696,8 +675,7 @@ TEST(RowSegmenter, NonOrdered) {
{{0, 2, false, true},
{2, 1, false, false},
{3, 1, false, false},
- {4, 1, true, false},
- {5, 0, true, true}});
+ {4, 1, true, false}});
}
{
std::vector<TypeHolder> types = {int32(), int32()};
@@ -707,8 +685,7 @@ TEST(RowSegmenter, NonOrdered) {
{{0, 2, false, true},
{2, 1, false, false},
{3, 1, false, false},
- {4, 1, true, false},
- {5, 0, true, true}});
+ {4, 1, true, false}});
}
}
@@ -767,8 +744,7 @@ TEST(RowSegmenter, MultipleSegments) {
{3, 1, false, false},
{4, 2, false, false},
{6, 2, false, false},
- {8, 1, true, false},
- {9, 0, true, true}});
+ {8, 1, true, false}});
}
{
std::vector<TypeHolder> types = {int32(), int32()};
@@ -782,8 +758,7 @@ TEST(RowSegmenter, MultipleSegments) {
{3, 1, false, false},
{4, 2, false, false},
{6, 2, false, false},
- {8, 1, true, false},
- {9, 0, true, true}});
+ {8, 1, true, false}});
}
}
@@ -845,7 +820,7 @@ void TestRowSegmenterConstantBatch(
std::vector<TypeHolder> key_types(types.begin(), types.begin() + size);
ARROW_ASSIGN_OR_RAISE(auto segmenter, make_segmenter(key_types));
for (size_t i = 0; i < repetitions; i++) {
- TestSegments(segmenter, ExecSpan(batch), {{0, 3, true, true}, {3, 0,
true, true}});
+ TestSegments(segmenter, ExecSpan(batch), {{0, 3, true, true}});
ARROW_RETURN_NOT_OK(segmenter->Reset());
}
return Status::OK();
@@ -893,10 +868,9 @@ TEST(RowSegmenter, RowConstantBatch) {
constexpr size_t n = 3;
std::vector<TypeHolder> types = {int32(), int32(), int32()};
auto full_batch = ExecBatchFromJSON(types, "[[1, 1, 1], [2, 2, 2], [3, 3,
3]]");
- std::vector<Segment> expected_segments_for_size_0 = {{0, 3, true, true},
- {3, 0, true, true}};
+ std::vector<Segment> expected_segments_for_size_0 = {{0, 3, true, true}};
std::vector<Segment> expected_segments = {
- {0, 1, false, true}, {1, 1, false, false}, {2, 1, true, false}, {3, 0,
true, true}};
+ {0, 1, false, true}, {1, 1, false, false}, {2, 1, true, false}};
auto test_by_size = [&](size_t size) -> Status {
SCOPED_TRACE("constant-batch with " + ToChars(size) + " key(s)");
std::vector<Datum> values(full_batch.values.begin(),
diff --git a/cpp/src/arrow/compute/row/grouper.cc
b/cpp/src/arrow/compute/row/grouper.cc
index 2b79539a3b..02ed186449 100644
--- a/cpp/src/arrow/compute/row/grouper.cc
+++ b/cpp/src/arrow/compute/row/grouper.cc
@@ -17,6 +17,7 @@
#include "arrow/compute/row/grouper.h"
+#include <iostream>
#include <memory>
#include <mutex>
#include <type_traits>
@@ -54,13 +55,8 @@ using group_id_t =
std::remove_const<decltype(kNoGroupId)>::type;
using GroupIdType = CTypeTraits<group_id_t>::ArrowType;
auto g_group_id_type = std::make_shared<GroupIdType>();
-inline const uint8_t* GetValuesAsBytes(const ArraySpan& data, int64_t offset =
0) {
- DCHECK_GT(data.type->byte_width(), 0);
- int64_t absolute_byte_offset = (data.offset + offset) *
data.type->byte_width();
- return data.GetValues<uint8_t>(1, absolute_byte_offset);
-}
-
template <typename Value>
+ARROW_DEPRECATED("Deprecated in 18.0.0 along with GetSegments.")
Status CheckForGetNextSegment(const std::vector<Value>& values, int64_t length,
int64_t offset, const std::vector<TypeHolder>&
key_types) {
if (offset < 0 || offset > length) {
@@ -82,11 +78,22 @@ Status CheckForGetNextSegment(const std::vector<Value>&
values, int64_t length,
}
template <typename Batch>
+ARROW_DEPRECATED("Deprecated in 18.0.0 along with GetSegments.")
enable_if_t<std::is_same<Batch, ExecSpan>::value || std::is_same<Batch,
ExecBatch>::value,
- Status>
-CheckForGetNextSegment(const Batch& batch, int64_t offset,
- const std::vector<TypeHolder>& key_types) {
+ Status> CheckForGetNextSegment(const Batch& batch, int64_t offset,
+ const std::vector<TypeHolder>&
key_types) {
+ ARROW_SUPPRESS_DEPRECATION_WARNING
return CheckForGetNextSegment(batch.values, batch.length, offset, key_types);
+ ARROW_UNSUPPRESS_DEPRECATION_WARNING
+}
+
+Status CheckForGetSegments(const ExecSpan& batch,
+ const std::vector<TypeHolder>& key_types) {
+ // TODO: Move the implementation of CheckForGetNextSegment here once we
remove the
+ // deprecated functions.
+ ARROW_SUPPRESS_DEPRECATION_WARNING
+ return CheckForGetNextSegment(batch, 0, key_types);
+ ARROW_UNSUPPRESS_DEPRECATION_WARNING
}
struct BaseRowSegmenter : public RowSegmenter {
@@ -102,21 +109,6 @@ Segment MakeSegment(int64_t batch_length, int64_t offset,
int64_t length, bool e
return Segment{offset, length, offset + length >= batch_length, extends};
}
-// Used by SimpleKeySegmenter::GetNextSegment to find the match-length of a
value within a
-// fixed-width buffer
-int64_t GetMatchLength(const uint8_t* match_bytes, int64_t match_width,
- const uint8_t* array_bytes, int64_t offset, int64_t
length) {
- int64_t cursor, byte_cursor;
- for (cursor = offset, byte_cursor = match_width * cursor; cursor < length;
- cursor++, byte_cursor += match_width) {
- if (memcmp(match_bytes, array_bytes + byte_cursor,
- static_cast<size_t>(match_width)) != 0) {
- break;
- }
- }
- return std::min(cursor, length) - offset;
-}
-
using ExtendFunc = std::function<bool(const void*)>;
constexpr bool kDefaultExtends = true; // by default, the first segment
extends
constexpr bool kEmptyExtends = true; // an empty segment extends too
@@ -130,9 +122,22 @@ struct NoKeysSegmenter : public BaseRowSegmenter {
Status Reset() override { return Status::OK(); }
+ ARROW_DEPRECATED("Deprecated in 18.0.0. Use GetSegments instead.")
Result<Segment> GetNextSegment(const ExecSpan& batch, int64_t offset)
override {
+ ARROW_SUPPRESS_DEPRECATION_WARNING
ARROW_RETURN_NOT_OK(CheckForGetNextSegment(batch, offset, {}));
return MakeSegment(batch.length, offset, batch.length - offset,
kDefaultExtends);
+ ARROW_UNSUPPRESS_DEPRECATION_WARNING
+ }
+
+ Result<std::vector<Segment>> GetSegments(const ExecSpan& batch) override {
+ RETURN_NOT_OK(CheckForGetSegments(batch, {}));
+
+ if (batch.length == 0) {
+ return std::vector<Segment>{};
+ }
+ return std::vector<Segment>{
+ MakeSegment(batch.length, 0, batch.length - 0, kDefaultExtends)};
}
};
@@ -147,13 +152,6 @@ struct SimpleKeySegmenter : public BaseRowSegmenter {
save_key_data_(static_cast<size_t>(key_type_.type->byte_width())),
extend_was_called_(false) {}
- Status CheckType(const DataType& type) {
- if (!is_fixed_width(type)) {
- return Status::Invalid("SimpleKeySegmenter does not support type ",
type);
- }
- return Status::OK();
- }
-
Status Reset() override {
extend_was_called_ = false;
return Status::OK();
@@ -161,7 +159,8 @@ struct SimpleKeySegmenter : public BaseRowSegmenter {
// Checks whether the given grouping data extends the current segment, i.e.,
is equal to
// previously seen grouping data, which is updated with each invocation.
- bool Extend(const void* data) {
+ ARROW_DEPRECATED("Deprecated in 18.0.0 along with GetSegments.")
+ bool ExtendDeprecated(const void* data) {
bool extends = !extend_was_called_
? kDefaultExtends
: 0 == memcmp(save_key_data_.data(), data,
save_key_data_.size());
@@ -170,42 +169,136 @@ struct SimpleKeySegmenter : public BaseRowSegmenter {
return extends;
}
- Result<Segment> GetNextSegment(const Scalar& scalar, int64_t offset, int64_t
length) {
+ ARROW_DEPRECATED("Deprecated in 18.0.0 along with GetSegments.")
+ Result<Segment> GetNextSegmentDeprecated(const Scalar& scalar, int64_t
offset,
+ int64_t length) {
+ ARROW_SUPPRESS_DEPRECATION_WARNING
ARROW_RETURN_NOT_OK(CheckType(*scalar.type));
if (!scalar.is_valid) {
return Status::Invalid("segmenting an invalid scalar");
}
auto data = checked_cast<const PrimitiveScalarBase&>(scalar).data();
- bool extends = length > 0 ? Extend(data) : kEmptyExtends;
+ bool extends = length > 0 ? ExtendDeprecated(data) : kEmptyExtends;
return MakeSegment(length, offset, length, extends);
+ ARROW_UNSUPPRESS_DEPRECATION_WARNING
}
- Result<Segment> GetNextSegment(const DataType& array_type, const uint8_t*
array_bytes,
- int64_t offset, int64_t length) {
+ ARROW_DEPRECATED("Deprecated in 18.0.0 along with GetSegments.")
+ Result<Segment> GetNextSegmentDeprecated(const DataType& array_type,
+ const uint8_t* array_bytes, int64_t
offset,
+ int64_t length) {
+ ARROW_SUPPRESS_DEPRECATION_WARNING
RETURN_NOT_OK(CheckType(array_type));
DCHECK_LE(offset, length);
int64_t byte_width = array_type.byte_width();
int64_t match_length = GetMatchLength(array_bytes + offset * byte_width,
byte_width,
array_bytes, offset, length);
- bool extends = length > 0 ? Extend(array_bytes + offset * byte_width) :
kEmptyExtends;
+ bool extends =
+ length > 0 ? ExtendDeprecated(array_bytes + offset * byte_width) :
kEmptyExtends;
return MakeSegment(length, offset, match_length, extends);
+ ARROW_UNSUPPRESS_DEPRECATION_WARNING
}
Result<Segment> GetNextSegment(const ExecSpan& batch, int64_t offset)
override {
+ ARROW_SUPPRESS_DEPRECATION_WARNING
ARROW_RETURN_NOT_OK(CheckForGetNextSegment(batch, offset, {key_type_}));
if (offset == batch.length) {
return MakeSegment(batch.length, offset, 0, kEmptyExtends);
}
const auto& value = batch.values[0];
if (value.is_scalar()) {
- return GetNextSegment(*value.scalar, offset, batch.length);
+ return GetNextSegmentDeprecated(*value.scalar, offset, batch.length);
}
ARROW_DCHECK(value.is_array());
const auto& array = value.array;
if (array.GetNullCount() > 0) {
return Status::NotImplemented("segmenting a nullable array");
}
- return GetNextSegment(*array.type, GetValuesAsBytes(array), offset,
batch.length);
+ return GetNextSegmentDeprecated(*array.type, GetValuesAsBytes(array),
offset,
+ batch.length);
+ ARROW_UNSUPPRESS_DEPRECATION_WARNING
+ }
+
+ Result<std::vector<Segment>> GetSegments(const ExecSpan& batch) override {
+ RETURN_NOT_OK(CheckForGetSegments(batch, {key_type_}));
+
+ if (batch.length == 0) {
+ return std::vector<Segment>{};
+ }
+
+ const auto& value = batch.values[0];
+ RETURN_NOT_OK(CheckType(*value.type()));
+
+ std::vector<Segment> segments;
+ const void* key_data;
+ if (value.is_scalar()) {
+ const auto& scalar = *value.scalar;
+ DCHECK(scalar.is_valid);
+ key_data = checked_cast<const PrimitiveScalarBase&>(scalar).data();
+ bool extends = Extend(key_data);
+ segments.push_back(MakeSegment(batch.length, 0, batch.length, extends));
+ } else {
+ DCHECK(value.is_array());
+ const auto& array = value.array;
+ DCHECK_EQ(array.GetNullCount(), 0);
+ auto data = GetValuesAsBytes(array);
+ int64_t byte_width = array.type->byte_width();
+ int64_t offset = 0;
+ bool extends = Extend(data);
+ while (offset < array.length) {
+ int64_t match_length = GetMatchLength(data + offset * byte_width,
byte_width,
+ data, offset, array.length);
+ segments.push_back(MakeSegment(array.length, offset, match_length,
+ offset == 0 ? extends : false));
+ offset += match_length;
+ }
+ key_data = data + (array.length - 1) * byte_width;
+ }
+
+ SaveKeyData(key_data);
+
+ return segments;
+ }
+
+ private:
+ static Status CheckType(const DataType& type) {
+ if (!is_fixed_width(type)) {
+ return Status::Invalid("SimpleKeySegmenter does not support type ",
type);
+ }
+ return Status::OK();
+ }
+
+ static const uint8_t* GetValuesAsBytes(const ArraySpan& data, int64_t offset
= 0) {
+ DCHECK_GT(data.type->byte_width(), 0);
+ int64_t absolute_byte_offset = (data.offset + offset) *
data.type->byte_width();
+ return data.GetValues<uint8_t>(1, absolute_byte_offset);
+ }
+
+ // Find the match-length of a value within a fixed-width buffer
+ static int64_t GetMatchLength(const uint8_t* match_bytes, int64_t
match_width,
+ const uint8_t* array_bytes, int64_t offset,
+ int64_t length) {
+ int64_t cursor, byte_cursor;
+ for (cursor = offset, byte_cursor = match_width * cursor; cursor < length;
+ cursor++, byte_cursor += match_width) {
+ if (memcmp(match_bytes, array_bytes + byte_cursor,
+ static_cast<size_t>(match_width)) != 0) {
+ break;
+ }
+ }
+ return std::min(cursor, length) - offset;
+ }
+
+ bool Extend(const void* data) {
+ if (ARROW_PREDICT_FALSE(!extend_was_called_)) {
+ extend_was_called_ = true;
+ return kDefaultExtends;
+ }
+ return 0 == memcmp(save_key_data_.data(), data, save_key_data_.size());
+ }
+
+ void SaveKeyData(const void* data) {
+ memcpy(save_key_data_.data(), data, save_key_data_.size());
}
private:
@@ -233,6 +326,7 @@ struct AnyKeysSegmenter : public BaseRowSegmenter {
return Status::OK();
}
+ ARROW_DEPRECATED("Deprecated in 18.0.0 along with GetSegments.")
bool Extend(const void* data) {
auto group_id = *static_cast<const group_id_t*>(data);
bool extends =
@@ -241,24 +335,9 @@ struct AnyKeysSegmenter : public BaseRowSegmenter {
return extends;
}
- // Runs the grouper on a single row. This is used to determine the group id
of the
- // first row of a new segment to see if it extends the previous segment.
- template <typename Batch>
- Result<group_id_t> MapGroupIdAt(const Batch& batch, int64_t offset) {
- ARROW_ASSIGN_OR_RAISE(auto datum, grouper_->Consume(batch, offset,
- /*length=*/1));
- if (!datum.is_array()) {
- return Status::Invalid("accessing unsupported datum kind ",
datum.kind());
- }
- const std::shared_ptr<ArrayData>& data = datum.array();
- ARROW_DCHECK(data->GetNullCount() == 0);
- DCHECK_EQ(data->type->id(), GroupIdType::type_id);
- DCHECK_EQ(1, data->length);
- const group_id_t* values = data->GetValues<group_id_t>(1);
- return values[0];
- }
-
+ ARROW_DEPRECATED("Deprecated in 18.0.0. Use GetSegments instead.")
Result<Segment> GetNextSegment(const ExecSpan& batch, int64_t offset)
override {
+ ARROW_SUPPRESS_DEPRECATION_WARNING
ARROW_RETURN_NOT_OK(CheckForGetNextSegment(batch, offset, key_types_));
if (offset == batch.length) {
return MakeSegment(batch.length, offset, 0, kEmptyExtends);
@@ -273,7 +352,7 @@ struct AnyKeysSegmenter : public BaseRowSegmenter {
};
// resetting drops grouper's group-ids, freeing-up memory for the next
segment
ARROW_RETURN_NOT_OK(grouper_->Reset());
- // GH-34475: cache the grouper-consume result across invocations of
GetNextSegment
+
ARROW_ASSIGN_OR_RAISE(auto datum, grouper_->Consume(batch, offset));
if (datum.is_array()) {
// `data` is an array whose index-0 corresponds to index `offset` of
`batch`
@@ -292,6 +371,76 @@ struct AnyKeysSegmenter : public BaseRowSegmenter {
} else {
return Status::Invalid("segmenting unsupported datum kind ",
datum.kind());
}
+ ARROW_UNSUPPRESS_DEPRECATION_WARNING
+ }
+
+ Result<std::vector<Segment>> GetSegments(const ExecSpan& batch) override {
+ RETURN_NOT_OK(CheckForGetSegments(batch, {key_types_}));
+
+ if (batch.length == 0) {
+ return std::vector<Segment>{};
+ }
+
+ // determine if the first segment in this batch extends the last segment
in the
+ // previous batch
+ bool extends = kDefaultExtends;
+ if (save_group_id_ != kNoGroupId) {
+ // the group id must be computed prior to resetting the grouper, since
it is
+ // compared to save_group_id_, and after resetting the grouper produces
incomparable
+ // group ids
+ ARROW_ASSIGN_OR_RAISE(auto group_id, MapGroupIdAt(batch));
+ // it "extends" unless the group id differs from the last group id
+ extends = (group_id == save_group_id_);
+ }
+
+ // resetting drops grouper's group-ids, freeing-up memory for the next
segment
+ RETURN_NOT_OK(grouper_->Reset());
+
+ std::vector<Segment> segments;
+ ARROW_ASSIGN_OR_RAISE(auto datum, grouper_->Consume(batch));
+ DCHECK(datum.is_array());
+ // `data` is an array whose index-0 corresponds to index `offset` of
`batch`
+ const std::shared_ptr<ArrayData>& data = datum.array();
+ DCHECK_EQ(data->length, batch.length);
+ DCHECK_EQ(data->GetNullCount(), 0);
+ DCHECK_EQ(data->type->id(), GroupIdType::type_id);
+ const group_id_t* group_ids = data->GetValues<group_id_t>(1);
+ int64_t current_group_offset = 0;
+ int64_t cursor;
+ for (cursor = 1; cursor < data->length; ++cursor) {
+ if (group_ids[cursor] != group_ids[current_group_offset]) {
+ segments.push_back(MakeSegment(batch.length, current_group_offset,
+ cursor - current_group_offset,
+ current_group_offset == 0 ? extends :
false));
+ current_group_offset = cursor;
+ }
+ }
+ segments.push_back(MakeSegment(batch.length, current_group_offset,
+ cursor - current_group_offset,
+ current_group_offset == 0 ? extends :
false));
+
+ // update the save_group_id_ to the last group id in this batch
+ save_group_id_ = group_ids[batch.length - 1];
+
+ return segments;
+ }
+
+ private:
+ // Runs the grouper on a single row. This is used to determine the group id
of the
+ // first row of a new segment to see if it extends the previous segment.
+ template <typename Batch>
+ Result<group_id_t> MapGroupIdAt(const Batch& batch, int64_t offset = 0) {
+ ARROW_ASSIGN_OR_RAISE(auto datum, grouper_->Consume(batch, offset,
+ /*length=*/1));
+ if (!datum.is_array()) {
+ return Status::Invalid("accessing unsupported datum kind ",
datum.kind());
+ }
+ const std::shared_ptr<ArrayData>& data = datum.array();
+ ARROW_DCHECK(data->GetNullCount() == 0);
+ DCHECK_EQ(data->type->id(), GroupIdType::type_id);
+ DCHECK_EQ(1, data->length);
+ const group_id_t* values = data->GetValues<group_id_t>(1);
+ return values[0];
}
private:
diff --git a/cpp/src/arrow/compute/row/grouper.h
b/cpp/src/arrow/compute/row/grouper.h
index 1d2aaae9df..345bc62924 100644
--- a/cpp/src/arrow/compute/row/grouper.h
+++ b/cpp/src/arrow/compute/row/grouper.h
@@ -97,7 +97,12 @@ class ARROW_EXPORT RowSegmenter {
virtual Status Reset() = 0;
/// \brief Get the next segment for the given batch starting from the given
offset
+ /// DEPRECATED: Due to its inefficiency, use GetSegments instead.
+ ARROW_DEPRECATED("Deprecated in 18.0.0. Use GetSegments instead.")
virtual Result<Segment> GetNextSegment(const ExecSpan& batch, int64_t
offset) = 0;
+
+ /// \brief Get all segments for the given batch
+ virtual Result<std::vector<Segment>> GetSegments(const ExecSpan& batch) = 0;
};
/// Consumes batches of keys and yields batches of the group ids.