This is an automated email from the ASF dual-hosted git repository.
apitrou pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/arrow.git
The following commit(s) were added to refs/heads/master by this push:
new ea4a405 ARROW-10051: [C++][Compute] Move kernel state when merging
ea4a405 is described below
commit ea4a405e1e988746632150167b936f8b17557e44
Author: Yibo Cai <[email protected]>
AuthorDate: Thu Sep 24 11:39:48 2020 +0200
ARROW-10051: [C++][Compute] Move kernel state when merging
An aggregate kernel consumes one batch and outputs a `state`, which is then
merged with the states of other batches. Currently, the `state` parameter is
declared as `const KernelState&` in the `merge` interface, which may cause
unnecessary data copying for kernels with non-trivial states.
E.g., the mode kernel maintains a value:count map in its `state` structure;
by removing `const` and taking the state as an rvalue reference
(`KernelState&&`), we can move the map rather than copy it.
Closes #8232 from cyb70289/agg-opt
Authored-by: Yibo Cai <[email protected]>
Signed-off-by: Antoine Pitrou <[email protected]>
---
cpp/src/arrow/compute/exec.cc | 2 +-
cpp/src/arrow/compute/kernel.h | 2 +-
cpp/src/arrow/compute/kernels/aggregate_basic.cc | 6 +++---
cpp/src/arrow/compute/kernels/aggregate_basic_internal.h | 6 +++---
cpp/src/arrow/compute/kernels/aggregate_mode.cc | 10 +++++-----
5 files changed, 13 insertions(+), 13 deletions(-)
diff --git a/cpp/src/arrow/compute/exec.cc b/cpp/src/arrow/compute/exec.cc
index 435d7dd..71bbc34 100644
--- a/cpp/src/arrow/compute/exec.cc
+++ b/cpp/src/arrow/compute/exec.cc
@@ -885,7 +885,7 @@ class ScalarAggExecutor : public
FunctionExecutorImpl<ScalarAggregateFunction> {
kernel_->consume(&batch_ctx, batch);
ARROW_CTX_RETURN_IF_ERROR(&batch_ctx);
- kernel_->merge(&kernel_ctx_, *batch_state, state_.get());
+ kernel_->merge(&kernel_ctx_, std::move(*batch_state), state_.get());
ARROW_CTX_RETURN_IF_ERROR(&kernel_ctx_);
return Status::OK();
}
diff --git a/cpp/src/arrow/compute/kernel.h b/cpp/src/arrow/compute/kernel.h
index 3fb6947..67cb5df 100644
--- a/cpp/src/arrow/compute/kernel.h
+++ b/cpp/src/arrow/compute/kernel.h
@@ -664,7 +664,7 @@ struct VectorKernel : public ArrayKernel {
using ScalarAggregateConsume = std::function<void(KernelContext*, const
ExecBatch&)>;
using ScalarAggregateMerge =
- std::function<void(KernelContext*, const KernelState&, KernelState*)>;
+ std::function<void(KernelContext*, KernelState&&, KernelState*)>;
// Finalize returns Datum to permit multiple return values
using ScalarAggregateFinalize = std::function<void(KernelContext*, Datum*)>;
diff --git a/cpp/src/arrow/compute/kernels/aggregate_basic.cc
b/cpp/src/arrow/compute/kernels/aggregate_basic.cc
index 33afd68..94914d0 100644
--- a/cpp/src/arrow/compute/kernels/aggregate_basic.cc
+++ b/cpp/src/arrow/compute/kernels/aggregate_basic.cc
@@ -30,8 +30,8 @@ void AggregateConsume(KernelContext* ctx, const ExecBatch&
batch) {
checked_cast<ScalarAggregator*>(ctx->state())->Consume(ctx, batch);
}
-void AggregateMerge(KernelContext* ctx, const KernelState& src, KernelState*
dst) {
- checked_cast<ScalarAggregator*>(dst)->MergeFrom(ctx, src);
+void AggregateMerge(KernelContext* ctx, KernelState&& src, KernelState* dst) {
+ checked_cast<ScalarAggregator*>(dst)->MergeFrom(ctx, std::move(src));
}
void AggregateFinalize(KernelContext* ctx, Datum* out) {
@@ -51,7 +51,7 @@ struct CountImpl : public ScalarAggregator {
this->non_nulls += input.length - nulls;
}
- void MergeFrom(KernelContext*, const KernelState& src) override {
+ void MergeFrom(KernelContext*, KernelState&& src) override {
const auto& other_state = checked_cast<const CountImpl&>(src);
this->non_nulls += other_state.non_nulls;
this->nulls += other_state.nulls;
diff --git a/cpp/src/arrow/compute/kernels/aggregate_basic_internal.h
b/cpp/src/arrow/compute/kernels/aggregate_basic_internal.h
index 3776b20..cd8390e 100644
--- a/cpp/src/arrow/compute/kernels/aggregate_basic_internal.h
+++ b/cpp/src/arrow/compute/kernels/aggregate_basic_internal.h
@@ -31,7 +31,7 @@ namespace aggregate {
struct ScalarAggregator : public KernelState {
virtual void Consume(KernelContext* ctx, const ExecBatch& batch) = 0;
- virtual void MergeFrom(KernelContext* ctx, const KernelState& src) = 0;
+ virtual void MergeFrom(KernelContext* ctx, KernelState&& src) = 0;
virtual void Finalize(KernelContext* ctx, Datum* out) = 0;
};
@@ -260,7 +260,7 @@ struct SumImpl : public ScalarAggregator {
this->state.Consume(ArrayType(batch[0].array()));
}
- void MergeFrom(KernelContext*, const KernelState& src) override {
+ void MergeFrom(KernelContext*, KernelState&& src) override {
const auto& other = checked_cast<const ThisType&>(src);
this->state += other.state;
}
@@ -433,7 +433,7 @@ struct MinMaxImpl : public ScalarAggregator {
this->state = local;
}
- void MergeFrom(KernelContext*, const KernelState& src) override {
+ void MergeFrom(KernelContext*, KernelState&& src) override {
const auto& other = checked_cast<const ThisType&>(src);
this->state += other.state;
}
diff --git a/cpp/src/arrow/compute/kernels/aggregate_mode.cc
b/cpp/src/arrow/compute/kernels/aggregate_mode.cc
index 7c93fb9..aadf1ce 100644
--- a/cpp/src/arrow/compute/kernels/aggregate_mode.cc
+++ b/cpp/src/arrow/compute/kernels/aggregate_mode.cc
@@ -155,9 +155,9 @@ struct ModeState {
using ThisType = ModeState<ArrowType>;
using CType = typename ArrowType::c_type;
- void MergeFrom(const ThisType& state) {
+ void MergeFrom(ThisType&& state) {
if (this->value_counts.empty()) {
- this->value_counts = state.value_counts;
+ this->value_counts = std::move(state.value_counts);
} else {
for (const auto& value_count : state.value_counts) {
auto value = value_count.first;
@@ -205,9 +205,9 @@ struct ModeImpl : public ScalarAggregator {
this->state.value_counts = CountValues(array, this->state.nan_count);
}
- void MergeFrom(KernelContext*, const KernelState& src) override {
- const auto& other = checked_cast<const ThisType&>(src);
- this->state.MergeFrom(other.state);
+ void MergeFrom(KernelContext*, KernelState&& src) override {
+ auto& other = checked_cast<ThisType&>(src);
+ this->state.MergeFrom(std::move(other.state));
}
void Finalize(KernelContext*, Datum* out) override {