westonpace commented on code in PR #34311:
URL: https://github.com/apache/arrow/pull/34311#discussion_r1128841550
##########
cpp/src/arrow/compute/exec/aggregate_node.cc:
##########
@@ -35,6 +36,25 @@
#include "arrow/util/thread_pool.h"
#include "arrow/util/tracing_internal.h"
+// This file implements both regular and segmented group-by aggregation, which is a
+// generalization of ordered aggregation in which the key columns are not required to be
+// ordered.
+//
+// In (regular) group-by aggregation, the input rows are partitioned into groups using a
+// set of columns called keys, where in a given group each row has the same values for
+// these columns. In segmented group-by aggregation, a second set of columns called
+// segment-keys is used to refine the partitioning. However, segment-keys are different in
+// that they partition only consecutive rows into a single group. Such a partition of
+// consecutive rows is called a segment group. For example, consider a column X with
+// values [A, B, A] at row-indices [0, 1, 2]. A regular group-by aggregation with keys [X]
+// yields a row-index partitioning [[0, 2], [1]] whereas a segmented-group-by aggregation
+// with segment-keys [X] yields [[0], [1], [2]].
Review Comment:
Minor nit: This example could be slightly improved I think if you used `[A, A, B, A]` so that readers could see that the segmented group by still does segment.
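For example (my own sketch of the suggested input, not text from the PR), the two modes would then differ visibly:
```
X = [A, A, B, A] at row-indices [0, 1, 2, 3]
group-by with keys [X]               -> [[0, 1, 3], [2]]
segmented group-by, segment-keys [X] -> [[0, 1], [2], [3]]
```
so rows 0 and 1 still end up in one segment group while row 3 starts a new one.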
##########
cpp/src/arrow/compute/exec/aggregate_node.cc:
##########
@@ -584,29 +797,52 @@ class GroupByNode : public ExecNode, public TracedNode {
ARROW_ASSIGN_OR_RAISE(out_data_, Finalize());
int64_t num_output_batches = bit_util::CeilDiv(out_data_.length, output_batch_size());
- RETURN_NOT_OK(output_->InputFinished(this, static_cast<int>(num_output_batches)));
- return plan_->query_context()->StartTaskGroup(output_task_group_id_,
- num_output_batches);
+ total_output_batches_ += static_cast<int>(num_output_batches);
+ if (is_last) {
+ ARROW_RETURN_NOT_OK(output_->InputFinished(this, total_output_batches_));
+ RETURN_NOT_OK(plan_->query_context()->StartTaskGroup(output_task_group_id_,
+ num_output_batches));
+ } else {
+ for (int64_t i = 0; i < num_output_batches; i++) {
+ ARROW_RETURN_NOT_OK(OutputNthBatch(i));
+ }
+ ARROW_RETURN_NOT_OK(ResetKernelStates());
+ }
+ return Status::OK();
}
Status InputReceived(ExecNode* input, ExecBatch batch) override {
auto scope = TraceInputReceived(batch);
DCHECK_EQ(input, inputs_[0]);
- ARROW_RETURN_NOT_OK(Consume(ExecSpan(batch)));
+ auto handler = [this](const ExecBatch& full_batch, const Segment& segment) {
+ if (!segment.extends && segment.offset == 0) RETURN_NOT_OK(OutputResult(false));
+ // This is not zero copy - we should refactor the code to pass
+ // offset and length to Consume to avoid copying here
Review Comment:
Same as above. `full_batch.Slice` is zero-copy I believe.
##########
cpp/src/arrow/compute/exec/aggregate_node.cc:
##########
@@ -169,20 +185,79 @@ void AggregatesToString(std::stringstream* ss, const Schema& input_schema,
*ss << ']';
}
+// Handle the input batch
+// If a segment is closed by this batch, then we output the aggregation for the segment
+// If a segment is not closed by this batch, then we add the batch to the segment
+template <typename BatchHandler>
+Status HandleSegments(std::unique_ptr<RowSegmenter>& segmenter, const ExecBatch& batch,
Review Comment:
```suggestion
Status HandleSegments(RowSegmenter* segmenter, const ExecBatch& batch,
```
Prefer pointer over mutable reference.
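If we make that change, the call site just passes the raw pointer; a minimal sketch (assuming the node keeps the segmenter in a `std::unique_ptr<RowSegmenter> segmenter_` member and the other argument names match this hunk):
```cpp
// Hypothetical call site after the signature change: passing a raw pointer makes
// the mutation of the segmenter visible at the call site (the usual style-guide
// rationale for preferring pointers over mutable references).
ARROW_RETURN_NOT_OK(HandleSegments(segmenter_.get(), batch, ids, handler));
```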
##########
cpp/src/arrow/compute/row/grouper.cc:
##########
@@ -39,12 +43,521 @@
namespace arrow {
using internal::checked_cast;
+using internal::PrimitiveScalarBase;
namespace compute {
namespace {
-struct GrouperImpl : Grouper {
+constexpr uint32_t kNoGroupId = std::numeric_limits<uint32_t>::max();
+
+using group_id_t = std::remove_const<decltype(kNoGroupId)>::type;
+using GroupIdType = CTypeTraits<group_id_t>::ArrowType;
+auto group_id_type = std::make_shared<GroupIdType>();
+
+inline const uint8_t* GetValuesAsBytes(const ArrayData& data, int64_t offset = 0) {
+ DCHECK_GT(data.type->byte_width(), 0);
+ int64_t absolute_byte_offset = (data.offset + offset) * data.type->byte_width();
+ return data.GetValues<uint8_t>(1, absolute_byte_offset);
+}
+
+inline const uint8_t* GetValuesAsBytes(const ArraySpan& data, int64_t offset = 0) {
+ DCHECK_GT(data.type->byte_width(), 0);
+ int64_t absolute_byte_offset = (data.offset + offset) * data.type->byte_width();
+ return data.GetValues<uint8_t>(1, absolute_byte_offset);
+}
+
+template <typename Value>
+Status CheckForGetNextSegment(const std::vector<Value>& values, int64_t length,
+ int64_t offset, const std::vector<TypeHolder>& key_types) {
+ if (offset < 0 || offset > length) {
+ return Status::Invalid("invalid grouping segmenter offset: ", offset);
+ }
+ if (values.size() != key_types.size()) {
+ return Status::Invalid("expected batch size ", key_types.size(), " but got
",
+ values.size());
+ }
+ for (size_t i = 0; i < key_types.size(); i++) {
+ const auto& value = values[i];
+ const auto& key_type = key_types[i];
+ if (*value.type() != *key_type.type) {
+ return Status::Invalid("expected batch value ", i, " of type ",
*key_type.type,
+ " but got ", *value.type());
+ }
+ }
+ return Status::OK();
+}
+
+template <typename Batch>
+enable_if_t<std::is_same<Batch, ExecSpan>::value || std::is_same<Batch, ExecBatch>::value,
+ Status>
+CheckForGetNextSegment(const Batch& batch, int64_t offset,
+ const std::vector<TypeHolder>& key_types) {
+ return CheckForGetNextSegment(batch.values, batch.length, offset, key_types);
+}
+
+struct BaseGroupingSegmenter : public GroupingSegmenter {
+ explicit BaseGroupingSegmenter(const std::vector<TypeHolder>& key_types)
+ : key_types_(key_types) {}
+
+ const std::vector<TypeHolder>& key_types() const override { return key_types_; }
+
+ std::vector<TypeHolder> key_types_;
+};
+
+GroupingSegment MakeSegment(int64_t batch_length, int64_t offset, int64_t length,
+ bool extends) {
+ return GroupingSegment{offset, length, offset + length >= batch_length, extends};
+}
+
+int64_t GetMatchLength(const uint8_t* match_bytes, int64_t match_width,
Review Comment:
To be fair, it's not clear yet that the performance cost here is significant, and it will depend on I/O. Once we implement "multi-threaded segmented aggregation" then it should be fairly clear. The "find the next segment" portion will be the spot that we have to serialize (we can consume batches, compute aggregates, and output results in parallel). So it should be clear from a trace (once we have tracing a bit better supported) if this spot is becoming a bottleneck.
I am in favor of deferring the optimization to future PRs.
##########
cpp/src/arrow/compute/row/grouper.cc:
##########
@@ -39,12 +43,336 @@
namespace arrow {
using internal::checked_cast;
+using internal::PrimitiveScalarBase;
namespace compute {
namespace {
-struct GrouperImpl : Grouper {
+constexpr uint32_t kNoGroupId = std::numeric_limits<uint32_t>::max();
+
+using group_id_t = std::remove_const<decltype(kNoGroupId)>::type;
+using GroupIdType = CTypeTraits<group_id_t>::ArrowType;
+auto g_group_id_type = std::make_shared<GroupIdType>();
+
+inline const uint8_t* GetValuesAsBytes(const ArraySpan& data, int64_t offset = 0) {
+ DCHECK_GT(data.type->byte_width(), 0);
+ int64_t absolute_byte_offset = (data.offset + offset) * data.type->byte_width();
+ return data.GetValues<uint8_t>(1, absolute_byte_offset);
+}
+
+template <typename Value>
+Status CheckForGetNextSegment(const std::vector<Value>& values, int64_t length,
+ int64_t offset, const std::vector<TypeHolder>& key_types) {
+ if (offset < 0 || offset > length) {
Review Comment:
I agree. My general rule of thumb is "would a user understand this error?"
or "could this error be triggered by invalid user input?" In this case, I
think the answer is "no" and so these could be `DCHECK`. I wouldn't count
"custom node developers" as users.
##########
cpp/src/arrow/compute/row/grouper.cc:
##########
@@ -39,12 +43,330 @@
namespace arrow {
using internal::checked_cast;
+using internal::PrimitiveScalarBase;
namespace compute {
namespace {
-struct GrouperImpl : Grouper {
+constexpr uint32_t kNoGroupId = std::numeric_limits<uint32_t>::max();
+
+using group_id_t = std::remove_const<decltype(kNoGroupId)>::type;
+using GroupIdType = CTypeTraits<group_id_t>::ArrowType;
+auto g_group_id_type = std::make_shared<GroupIdType>();
+
+inline const uint8_t* GetValuesAsBytes(const ArraySpan& data, int64_t offset = 0) {
+ DCHECK_GT(data.type->byte_width(), 0);
+ int64_t absolute_byte_offset = (data.offset + offset) * data.type->byte_width();
+ return data.GetValues<uint8_t>(1, absolute_byte_offset);
+}
+
+template <typename Value>
+Status CheckForGetNextSegment(const std::vector<Value>& values, int64_t length,
+ int64_t offset, const std::vector<TypeHolder>& key_types) {
+ if (offset < 0 || offset > length) {
+ return Status::Invalid("invalid grouping segmenter offset: ", offset);
+ }
+ if (values.size() != key_types.size()) {
+ return Status::Invalid("expected batch size ", key_types.size(), " but got
",
+ values.size());
+ }
+ for (size_t i = 0; i < key_types.size(); i++) {
+ const auto& value = values[i];
+ const auto& key_type = key_types[i];
+ if (*value.type() != *key_type.type) {
+ return Status::Invalid("expected batch value ", i, " of type ",
*key_type.type,
+ " but got ", *value.type());
+ }
+ }
+ return Status::OK();
+}
+
+template <typename Batch>
+enable_if_t<std::is_same<Batch, ExecSpan>::value || std::is_same<Batch, ExecBatch>::value,
+ Status>
+CheckForGetNextSegment(const Batch& batch, int64_t offset,
+ const std::vector<TypeHolder>& key_types) {
+ return CheckForGetNextSegment(batch.values, batch.length, offset, key_types);
+}
+
+struct BaseRowSegmenter : public RowSegmenter {
+ explicit BaseRowSegmenter(const std::vector<TypeHolder>& key_types)
+ : key_types_(key_types) {}
+
+ const std::vector<TypeHolder>& key_types() const override { return key_types_; }
+
+ std::vector<TypeHolder> key_types_;
+};
+
+Segment MakeSegment(int64_t batch_length, int64_t offset, int64_t length, bool extends) {
+ return Segment{offset, length, offset + length >= batch_length, extends};
+}
+
+int64_t GetMatchLength(const uint8_t* match_bytes, int64_t match_width,
+ const uint8_t* array_bytes, int64_t offset, int64_t length) {
+ int64_t cursor, byte_cursor;
+ for (cursor = offset, byte_cursor = match_width * cursor; cursor < length;
+ cursor++, byte_cursor += match_width) {
+ if (memcmp(match_bytes, array_bytes + byte_cursor,
+ static_cast<size_t>(match_width)) != 0) {
+ break;
+ }
+ }
+ return std::min(cursor, length) - offset;
+}
+
+using ExtendFunc = std::function<bool(const void*)>;
+constexpr bool kDefaultExtends = true;
+constexpr bool kEmptyExtends = true;
+
+struct NoKeysSegmenter : public BaseRowSegmenter {
+ static std::unique_ptr<RowSegmenter> Make() {
+ return std::make_unique<NoKeysSegmenter>();
+ }
+
+ NoKeysSegmenter() : BaseRowSegmenter({}) {}
+
+ Status Reset() override { return Status::OK(); }
+
+ Result<Segment> GetNextSegment(const ExecSpan& batch, int64_t offset) override {
+ ARROW_RETURN_NOT_OK(CheckForGetNextSegment(batch, offset, {}));
+ return MakeSegment(batch.length, offset, batch.length - offset, kDefaultExtends);
+ }
+};
+
+struct SimpleKeySegmenter : public BaseRowSegmenter {
+ static Result<std::unique_ptr<RowSegmenter>> Make(TypeHolder key_type) {
+ return std::make_unique<SimpleKeySegmenter>(key_type);
+ }
+
+ explicit SimpleKeySegmenter(TypeHolder key_type)
+ : BaseRowSegmenter({key_type}), key_type_(key_types_[0]), save_key_data_() {}
+
+ Status CheckType(const DataType& type) {
+ if (!is_fixed_width(type)) {
+ return Status::Invalid("SimpleKeySegmenter does not support type ",
type);
+ }
+ return Status::OK();
+ }
+
+ Status Reset() override {
+ save_key_data_.resize(0);
+ return Status::OK();
+ }
+
+ // Checks whether the given grouping data extends the current segment, i.e., is equal to
+ // previously seen grouping data, which is updated with each invocation.
+ bool Extend(const void* data) {
+ size_t byte_width = static_cast<size_t>(key_type_.type->byte_width());
+ bool extends = save_key_data_.size() != byte_width
+ ? kDefaultExtends
+ : 0 == memcmp(save_key_data_.data(), data, byte_width);
+ save_key_data_.resize(byte_width);
+ memcpy(save_key_data_.data(), data, byte_width);
+ return extends;
+ }
+
+ Result<Segment> GetNextSegment(const Scalar& scalar, int64_t offset, int64_t length) {
+ ARROW_RETURN_NOT_OK(CheckType(*scalar.type));
+ if (!scalar.is_valid) {
+ return Status::Invalid("segmenting an invalid scalar");
+ }
+ auto data = checked_cast<const PrimitiveScalarBase&>(scalar).data();
+ bool extends = length > 0 ? Extend(data) : kEmptyExtends;
+ return MakeSegment(length, offset, length, extends);
+ }
+
+ Result<Segment> GetNextSegment(const DataType& array_type, const uint8_t* array_bytes,
+ int64_t offset, int64_t length) {
+ RETURN_NOT_OK(CheckType(array_type));
+ int64_t byte_width = array_type.byte_width();
+ int64_t match_length = GetMatchLength(array_bytes + offset * byte_width, byte_width,
+ array_bytes, offset, length);
+ bool extends = length > 0 ? Extend(array_bytes + offset * byte_width) : kEmptyExtends;
+ return MakeSegment(length, offset, match_length, extends);
+ }
+
+ Result<Segment> GetNextSegment(const ExecSpan& batch, int64_t offset) override {
+ ARROW_RETURN_NOT_OK(CheckForGetNextSegment(batch, offset, {key_type_}));
+ if (offset == batch.length) {
+ return MakeSegment(batch.length, offset, 0, kEmptyExtends);
+ }
+ const auto& value = batch.values[0];
+ if (value.is_scalar()) {
+ return GetNextSegment(*value.scalar, offset, batch.length);
+ }
+ ARROW_DCHECK(value.is_array());
+ const auto& array = value.array;
+ if (array.GetNullCount() > 0) {
+ return Status::NotImplemented("segmenting a nullable array");
+ }
+ return GetNextSegment(*array.type, GetValuesAsBytes(array), offset, batch.length);
+ }
+
+ private:
+ TypeHolder key_type_;
+ std::vector<uint8_t> save_key_data_;
+};
+
+struct AnyKeysSegmenter : public BaseRowSegmenter {
+ static Result<std::unique_ptr<RowSegmenter>> Make(
+ const std::vector<TypeHolder>& key_types, ExecContext* ctx) {
+ ARROW_RETURN_NOT_OK(Grouper::Make(key_types, ctx)); // check types
+ return std::make_unique<AnyKeysSegmenter>(key_types, ctx);
+ }
+
+ AnyKeysSegmenter(const std::vector<TypeHolder>& key_types, ExecContext* ctx)
+ : BaseRowSegmenter(key_types),
+ ctx_(ctx),
+ grouper_(nullptr),
+ save_group_id_(kNoGroupId) {}
+
+ Status Reset() override {
Review Comment:
We should also keep in mind that "custom exec node developers" are a valid
persona to support. As we build up utilities like these we could in theory
even expose them to python (e.g. pyarrow users have asked for things like
ExecBatchBuilder and the row table in the past). I think it's ok to err a
little bit on the side of "a complete abstraction" as long as it doesn't go
overboard.
To help with review length in the future, we could probably have reviewed the segmenters independently and then come back and reviewed the changes to the node itself. But that ship has sailed at this point :)
##########
cpp/src/arrow/compute/exec/aggregate_node.cc:
##########
@@ -169,20 +185,79 @@ void AggregatesToString(std::stringstream* ss, const Schema& input_schema,
*ss << ']';
}
+// Handle the input batch
+// If a segment is closed by this batch, then we output the aggregation for the segment
+// If a segment is not closed by this batch, then we add the batch to the segment
Review Comment:
```suggestion
// Extract segments from a batch and run the given handler on them. Note that the
// handler may be called on open segments which are not yet finished. Typically a
// handler should accumulate those open segments until a closed segment is reached.
```
##########
cpp/src/arrow/compute/exec/aggregate_node.cc:
##########
@@ -283,28 +393,47 @@ class ScalarAggregateNode : public ExecNode, public TracedNode {
DCHECK_EQ(input, inputs_[0]);
auto thread_index = plan_->query_context()->GetThreadIndex();
-
- ARROW_RETURN_NOT_OK(DoConsume(ExecSpan(batch), thread_index));
+ auto handler = [this, thread_index](const ExecBatch& full_batch,
+ const Segment& segment) {
+ // (1) The segment is starting of a new segment group and points to
+ // the beginning of the batch, then it means no data in the batch belongs
+ // to the current segment group. We can output and reset kernel states.
+ if (!segment.extends && segment.offset == 0) RETURN_NOT_OK(OutputResult(false));
+
+ // We add segment to the current segment group aggregation
+ // GH-34475: change to zero-copy slicing
+ auto exec_batch = full_batch.Slice(segment.offset, segment.length);
Review Comment:
Isn't `full_batch.Slice` a zero-copy operation? I'm not sure I understand
why we need a follow-up.
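For reference, my understanding (not something spelled out in this PR) is that the slice only re-slices the underlying `ArrayData`, which shares buffers rather than copying them, roughly:
```cpp
// Rough illustration of why the Slice should already be zero-copy (assuming
// ExecBatch::Slice defers to ArrayData::Slice for array values): the sliced
// batch's ArrayData re-uses the parent's buffers and only adjusts offset/length.
ExecBatch sliced = full_batch.Slice(segment.offset, segment.length);
// sliced.values[i].array()->buffers hold the same shared_ptr<Buffer>s as
// full_batch.values[i].array()->buffers; no element data is copied.
```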
##########
cpp/src/arrow/compute/exec/aggregate_node.cc:
##########
@@ -169,35 +185,117 @@ void AggregatesToString(std::stringstream* ss, const Schema& input_schema,
*ss << ']';
}
+template <typename BatchHandler>
+Status HandleSegments(std::unique_ptr<GroupingSegmenter>& segmenter,
+ const ExecBatch& batch, const std::vector<int>& ids,
+ const BatchHandler& handle_batch) {
+ int64_t offset = 0;
+ ARROW_ASSIGN_OR_RAISE(auto segment_exec_batch, batch.SelectValues(ids));
+ ExecSpan segment_batch(segment_exec_batch);
+ while (true) {
+ ARROW_ASSIGN_OR_RAISE(auto segment, segmenter->GetNextSegment(segment_batch, offset));
+ if (segment.offset >= segment_batch.length) break; // condition of no-next-segment
+ ARROW_RETURN_NOT_OK(handle_batch(batch, segment));
+ offset = segment.offset + segment.length;
+ }
+ return Status::OK();
+}
+
+Status GetScalarFields(std::vector<Datum>* values_ptr, const ExecBatch& input_batch,
+ const std::vector<int>& field_ids) {
+ DCHECK_GT(input_batch.length, 0);
+ std::vector<Datum>& values = *values_ptr;
+ int64_t row = input_batch.length - 1;
+ values.clear();
+ values.resize(field_ids.size());
+ for (size_t i = 0; i < field_ids.size(); i++) {
+ const Datum& value = input_batch.values[field_ids[i]];
+ if (value.is_scalar()) {
Review Comment:
> Hmm.. does the input_batch not have a universal interface to get the value of row i regardless of whether it is a constant or not?
It does not. `Datum` arose early on in the compute infrastructure and can actually be quite a few things (chunked array, record batch, etc.), so in that context "get the value at index i" isn't universal. I wouldn't be surprised if there is some helper function for this somewhere but I don't know what it is.
Typically data values are only accessed within compute kernels and often we have completely separate kernels for working with scalars vs. arrays. And, for the kernels that handle both, there are a lot of helper classes for "do x for each item in scalar/array", but those are inside the compute kernels logic and probably dealing with spans, etc.
Within Acero, a `Datum` will only be `Scalar` or `Array`. I think this is probably fine. Maybe there is a longer term refactor to use something more limited in Acero instead of `Datum` to help avoid this confusion.
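If such a helper ever becomes worthwhile, something along these lines could cover the two kinds that show up in Acero; this is purely a hypothetical sketch, not an existing Arrow API:
```cpp
#include "arrow/array.h"   // Array::GetScalar
#include "arrow/datum.h"   // Datum
#include "arrow/result.h"  // Result
#include "arrow/scalar.h"  // Scalar

// Hypothetical "value at row i" helper for a Datum that, as in Acero, is known
// to be either a Scalar (one constant value for every row) or an Array.
arrow::Result<std::shared_ptr<arrow::Scalar>> ValueAtRow(const arrow::Datum& value,
                                                         int64_t row) {
  if (value.is_scalar()) {
    return value.scalar();  // the same value regardless of row
  }
  if (value.is_array()) {
    return value.make_array()->GetScalar(row);  // materializes a Scalar
  }
  return arrow::Status::NotImplemented("value-at-row for Datum kind ",
                                       value.ToString());
}
```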
##########
cpp/src/arrow/compute/row/grouper.cc:
##########
@@ -39,12 +43,330 @@
namespace arrow {
using internal::checked_cast;
+using internal::PrimitiveScalarBase;
namespace compute {
namespace {
-struct GrouperImpl : Grouper {
+constexpr uint32_t kNoGroupId = std::numeric_limits<uint32_t>::max();
+
+using group_id_t = std::remove_const<decltype(kNoGroupId)>::type;
+using GroupIdType = CTypeTraits<group_id_t>::ArrowType;
+auto g_group_id_type = std::make_shared<GroupIdType>();
+
+inline const uint8_t* GetValuesAsBytes(const ArraySpan& data, int64_t offset = 0) {
+ DCHECK_GT(data.type->byte_width(), 0);
+ int64_t absolute_byte_offset = (data.offset + offset) * data.type->byte_width();
+ return data.GetValues<uint8_t>(1, absolute_byte_offset);
+}
+
+template <typename Value>
+Status CheckForGetNextSegment(const std::vector<Value>& values, int64_t length,
+ int64_t offset, const std::vector<TypeHolder>& key_types) {
+ if (offset < 0 || offset > length) {
+ return Status::Invalid("invalid grouping segmenter offset: ", offset);
+ }
+ if (values.size() != key_types.size()) {
+ return Status::Invalid("expected batch size ", key_types.size(), " but got
",
+ values.size());
+ }
+ for (size_t i = 0; i < key_types.size(); i++) {
+ const auto& value = values[i];
+ const auto& key_type = key_types[i];
+ if (*value.type() != *key_type.type) {
+ return Status::Invalid("expected batch value ", i, " of type ",
*key_type.type,
+ " but got ", *value.type());
+ }
+ }
+ return Status::OK();
+}
+
+template <typename Batch>
+enable_if_t<std::is_same<Batch, ExecSpan>::value || std::is_same<Batch, ExecBatch>::value,
+ Status>
+CheckForGetNextSegment(const Batch& batch, int64_t offset,
+ const std::vector<TypeHolder>& key_types) {
+ return CheckForGetNextSegment(batch.values, batch.length, offset, key_types);
+}
+
+struct BaseRowSegmenter : public RowSegmenter {
+ explicit BaseRowSegmenter(const std::vector<TypeHolder>& key_types)
+ : key_types_(key_types) {}
+
+ const std::vector<TypeHolder>& key_types() const override { return key_types_; }
+
+ std::vector<TypeHolder> key_types_;
+};
+
+Segment MakeSegment(int64_t batch_length, int64_t offset, int64_t length, bool extends) {
+ return Segment{offset, length, offset + length >= batch_length, extends};
+}
+
+int64_t GetMatchLength(const uint8_t* match_bytes, int64_t match_width,
+ const uint8_t* array_bytes, int64_t offset, int64_t length) {
+ int64_t cursor, byte_cursor;
+ for (cursor = offset, byte_cursor = match_width * cursor; cursor < length;
+ cursor++, byte_cursor += match_width) {
+ if (memcmp(match_bytes, array_bytes + byte_cursor,
+ static_cast<size_t>(match_width)) != 0) {
+ break;
+ }
+ }
+ return std::min(cursor, length) - offset;
+}
+
+using ExtendFunc = std::function<bool(const void*)>;
+constexpr bool kDefaultExtends = true;
+constexpr bool kEmptyExtends = true;
+
+struct NoKeysSegmenter : public BaseRowSegmenter {
+ static std::unique_ptr<RowSegmenter> Make() {
+ return std::make_unique<NoKeysSegmenter>();
+ }
+
+ NoKeysSegmenter() : BaseRowSegmenter({}) {}
+
+ Status Reset() override { return Status::OK(); }
+
+ Result<Segment> GetNextSegment(const ExecSpan& batch, int64_t offset) override {
+ ARROW_RETURN_NOT_OK(CheckForGetNextSegment(batch, offset, {}));
+ return MakeSegment(batch.length, offset, batch.length - offset, kDefaultExtends);
+ }
+};
+
+struct SimpleKeySegmenter : public BaseRowSegmenter {
+ static Result<std::unique_ptr<RowSegmenter>> Make(TypeHolder key_type) {
+ return std::make_unique<SimpleKeySegmenter>(key_type);
+ }
+
+ explicit SimpleKeySegmenter(TypeHolder key_type)
+ : BaseRowSegmenter({key_type}), key_type_(key_types_[0]), save_key_data_() {}
+
+ Status CheckType(const DataType& type) {
+ if (!is_fixed_width(type)) {
+ return Status::Invalid("SimpleKeySegmenter does not support type ",
type);
+ }
+ return Status::OK();
+ }
+
+ Status Reset() override {
+ save_key_data_.resize(0);
+ return Status::OK();
+ }
+
+ // Checks whether the given grouping data extends the current segment, i.e., is equal to
+ // previously seen grouping data, which is updated with each invocation.
+ bool Extend(const void* data) {
+ size_t byte_width = static_cast<size_t>(key_type_.type->byte_width());
+ bool extends = save_key_data_.size() != byte_width
+ ? kDefaultExtends
+ : 0 == memcmp(save_key_data_.data(), data, byte_width);
+ save_key_data_.resize(byte_width);
+ memcpy(save_key_data_.data(), data, byte_width);
+ return extends;
+ }
+
+ Result<Segment> GetNextSegment(const Scalar& scalar, int64_t offset, int64_t length) {
+ ARROW_RETURN_NOT_OK(CheckType(*scalar.type));
+ if (!scalar.is_valid) {
+ return Status::Invalid("segmenting an invalid scalar");
+ }
+ auto data = checked_cast<const PrimitiveScalarBase&>(scalar).data();
+ bool extends = length > 0 ? Extend(data) : kEmptyExtends;
+ return MakeSegment(length, offset, length, extends);
+ }
+
+ Result<Segment> GetNextSegment(const DataType& array_type, const uint8_t* array_bytes,
+ int64_t offset, int64_t length) {
+ RETURN_NOT_OK(CheckType(array_type));
+ int64_t byte_width = array_type.byte_width();
+ int64_t match_length = GetMatchLength(array_bytes + offset * byte_width, byte_width,
+ array_bytes, offset, length);
+ bool extends = length > 0 ? Extend(array_bytes + offset * byte_width) : kEmptyExtends;
+ return MakeSegment(length, offset, match_length, extends);
+ }
+
+ Result<Segment> GetNextSegment(const ExecSpan& batch, int64_t offset) override {
+ ARROW_RETURN_NOT_OK(CheckForGetNextSegment(batch, offset, {key_type_}));
+ if (offset == batch.length) {
+ return MakeSegment(batch.length, offset, 0, kEmptyExtends);
+ }
+ const auto& value = batch.values[0];
+ if (value.is_scalar()) {
+ return GetNextSegment(*value.scalar, offset, batch.length);
+ }
+ ARROW_DCHECK(value.is_array());
+ const auto& array = value.array;
+ if (array.GetNullCount() > 0) {
+ return Status::NotImplemented("segmenting a nullable array");
+ }
+ return GetNextSegment(*array.type, GetValuesAsBytes(array), offset, batch.length);
+ }
+
+ private:
+ TypeHolder key_type_;
+ std::vector<uint8_t> save_key_data_;
+};
+
+struct AnyKeysSegmenter : public BaseRowSegmenter {
+ static Result<std::unique_ptr<RowSegmenter>> Make(
+ const std::vector<TypeHolder>& key_types, ExecContext* ctx) {
+ ARROW_RETURN_NOT_OK(Grouper::Make(key_types, ctx)); // check types
+ return std::make_unique<AnyKeysSegmenter>(key_types, ctx);
+ }
+
+ AnyKeysSegmenter(const std::vector<TypeHolder>& key_types, ExecContext* ctx)
+ : BaseRowSegmenter(key_types),
+ ctx_(ctx),
+ grouper_(nullptr),
+ save_group_id_(kNoGroupId) {}
+
+ Status Reset() override {
+ grouper_ = nullptr;
+ save_group_id_ = kNoGroupId;
+ return Status::OK();
+ }
+
+ bool Extend(const void* data) {
+ auto group_id = *static_cast<const group_id_t*>(data);
+ bool extends =
+ save_group_id_ == kNoGroupId ? kDefaultExtends : save_group_id_ == group_id;
+ save_group_id_ = group_id;
+ return extends;
+ }
+
+ // Runs the grouper on a single row. This is used to determine the group id of the
+ // first row of a new segment to see if it extends the previous segment.
+ template <typename Batch>
+ Result<group_id_t> MapGroupIdAt(const Batch& batch, int64_t offset) {
+ if (!grouper_) return kNoGroupId;
+ ARROW_ASSIGN_OR_RAISE(auto datum, grouper_->Consume(batch, offset,
+ /*length=*/1));
+ if (!datum.is_array()) {
+ return Status::Invalid("accessing unsupported datum kind ",
datum.kind());
+ }
+ const std::shared_ptr<ArrayData>& data = datum.array();
+ ARROW_DCHECK(data->GetNullCount() == 0);
+ DCHECK_EQ(data->type->id(), GroupIdType::type_id);
+ DCHECK_EQ(1, data->length);
+ const group_id_t* values = data->GetValues<group_id_t>(1);
Review Comment:
I'd enjoy a pattern like this someday:
https://github.com/apache/arrow/blob/main/cpp/src/arrow/compute/light_array.h#L118-L121
but for the time being @rtpsw is probably right. This is part of the
"implicit knowledge" required for arrow-c++.
##########
cpp/src/arrow/compute/exec/options.h:
##########
@@ -199,21 +199,39 @@ class ARROW_EXPORT ProjectNodeOptions : public ExecNodeOptions {
std::vector<std::string> names;
};
-/// \brief Make a node which aggregates input batches, optionally grouped by keys.
+/// \brief Make a node which aggregates input batches, optionally grouped by keys and
+/// optionally segmented by segment-keys. Both keys and segment-keys determine the group.
+/// However segment-keys are also used for determining grouping segments, which should be
+/// large, and allow streaming a partial aggregation result after processing each segment.
+/// One common use-case for segment-keys is ordered aggregation, in which the segment-key
+/// attribute specifies a column with non-decreasing values or a lexicographically-ordered
+/// set of such columns.
///
/// If the keys attribute is a non-empty vector, then each aggregate in `aggregates` is
/// expected to be a HashAggregate function. If the keys attribute is an empty vector,
/// then each aggregate is assumed to be a ScalarAggregate function.
+///
+/// If the segment_keys attribute is a non-empty vector, then segmented aggregation, as
+/// described above, applies.
+///
+/// The keys and segment_keys vectors must be disjoint.
+///
+/// See also doc in `aggregate_node.cc`
Review Comment:
```suggestion
```
This documentation is for users. I'm not sure we should be directing users
to `aggregate_node.cc`. Also, it's not clear what doc this is referring to. I
think this is fine as it is without the "see also".