cyb70289 commented on a change in pull request #9358:
URL: https://github.com/apache/arrow/pull/9358#discussion_r569193045
##########
File path: cpp/src/arrow/compute/kernels/aggregate_quantile.cc
##########
@@ -239,6 +229,243 @@ struct QuantileExecutor {
}
};
+// histogram approach with constant memory, only for integers within limited
value range
+template <typename InType>
+struct CountQuantiler {
+ using CType = typename InType::c_type;
+
+ CType min;
+ std::vector<uint64_t> counts; // counts[i]: # of values equals i + min
+
+ // indices to adjacent non-empty bins covering current quantile
+ struct AdjacentBins {
+ int left_index;
+ int right_index;
+ uint64_t total_count; // accumulated counts till left_index (inclusive)
+ };
+
+ CountQuantiler(CType min, CType max) {
+ uint32_t value_range = static_cast<uint32_t>(max - min) + 1;
+ DCHECK_LT(value_range, 1 << 30);
+ this->min = min;
+ this->counts.resize(value_range);
+ }
+
+ void Exec(KernelContext* ctx, const ExecBatch& batch, Datum* out) {
+ const QuantileOptions& options = QuantileState::Get(ctx);
+
+ // count values in all chunks, ignore nulls
+ const Datum& datum = batch[0];
+ const int64_t in_length = datum.length() - datum.null_count();
+ if (in_length > 0) {
+ for (auto& c : this->counts) c = 0;
+ for (const auto& array : datum.chunks()) {
+ const ArrayData& data = *array->data();
+ const CType* values = data.GetValues<CType>(1);
+ VisitSetBitRunsVoid(data.buffers[0], data.offset, data.length,
+ [&](int64_t pos, int64_t len) {
+ for (int64_t i = 0; i < len; ++i) {
+ ++this->counts[values[pos + i] - this->min];
+ }
+ });
+ }
+ }
+
+ // prepare out array
+ int64_t out_length = options.q.size();
+ if (in_length == 0) {
+ out_length = 0; // input is empty or only contains null, return empty
array
+ }
+ // out type depends on options
+ const bool is_datapoint = IsDataPoint(options);
+ const std::shared_ptr<DataType> out_type =
+ is_datapoint ? TypeTraits<InType>::type_singleton() : float64();
+ auto out_data = ArrayData::Make(out_type, out_length, 0);
+ out_data->buffers.resize(2, nullptr);
+
+ // calculate quantiles
+ if (out_length > 0) {
+ const auto out_bit_width =
checked_pointer_cast<NumberType>(out_type)->bit_width();
+ KERNEL_ASSIGN_OR_RAISE(out_data->buffers[1], ctx,
+ ctx->Allocate(out_length * out_bit_width / 8));
+
+ // find quantiles in ascending order
+ std::vector<int64_t> q_indices(out_length);
+ std::iota(q_indices.begin(), q_indices.end(), 0);
+ std::sort(q_indices.begin(), q_indices.end(),
+ [&options](int64_t left_index, int64_t right_index) {
+ return options.q[left_index] < options.q[right_index];
+ });
+
+ AdjacentBins bins{0, 0, this->counts[0]};
+ if (is_datapoint) {
+ CType* out_buffer = out_data->template GetMutableValues<CType>(1);
+ for (int64_t i = 0; i < out_length; ++i) {
+ const int64_t q_index = q_indices[i];
+ out_buffer[q_index] = GetQuantileAtDataPoint(
+ in_length, &bins, options.q[q_index], options.interpolation);
+ }
+ } else {
+ double* out_buffer = out_data->template GetMutableValues<double>(1);
+ for (int64_t i = 0; i < out_length; ++i) {
+ const int64_t q_index = q_indices[i];
+ out_buffer[q_index] = GetQuantileByInterp(in_length, &bins,
options.q[q_index],
+ options.interpolation);
+ }
+ }
+ }
+
+ *out = Datum(std::move(out_data));
+ }
+
+ // return quantile located exactly at some input data point
+ CType GetQuantileAtDataPoint(int64_t in_length, AdjacentBins* bins, double q,
+ enum QuantileOptions::Interpolation
interpolation) {
+ const uint64_t datapoint_index = QuantileToDataPoint(in_length, q,
interpolation);
+ while (datapoint_index >= bins->total_count &&
+ static_cast<size_t>(bins->left_index) < this->counts.size() - 1) {
+ ++bins->left_index;
+ bins->total_count += this->counts[bins->left_index];
+ }
+ DCHECK_LT(datapoint_index, bins->total_count);
+ return static_cast<CType>(bins->left_index + this->min);
+ }
+
+ // return quantile interpolated from adjacent input data points
+ double GetQuantileByInterp(int64_t in_length, AdjacentBins* bins, double q,
+ enum QuantileOptions::Interpolation
interpolation) {
+ const double index = (in_length - 1) * q;
+ const uint64_t index_floor = static_cast<uint64_t>(index);
+ const double fraction = index - index_floor;
+
+ while (index_floor >= bins->total_count &&
+ static_cast<size_t>(bins->left_index) < this->counts.size() - 1) {
+ ++bins->left_index;
+ bins->total_count += this->counts[bins->left_index];
+ }
+ DCHECK_LT(index_floor, bins->total_count);
+ const double lower_value = static_cast<double>(bins->left_index +
this->min);
+
+ // quantile lies in this bin, no interpolation needed
+ if (index <= bins->total_count - 1) {
+ return lower_value;
+ }
+
+ // quantile lies across two bins, locate next bin if not already done
+ DCHECK_EQ(index_floor, bins->total_count - 1);
+ if (bins->right_index <= bins->left_index) {
+ bins->right_index = bins->left_index + 1;
+ while (static_cast<size_t>(bins->right_index) < this->counts.size() - 1
&&
+ this->counts[bins->right_index] == 0) {
+ ++bins->right_index;
+ }
+ }
+ DCHECK_LT(static_cast<size_t>(bins->right_index), this->counts.size());
+ DCHECK_GT(this->counts[bins->right_index], 0);
+ const double higher_value = static_cast<double>(bins->right_index +
this->min);
+
+ if (interpolation == QuantileOptions::LINEAR) {
+ return fraction * higher_value + (1 - fraction) * lower_value;
+ } else if (interpolation == QuantileOptions::MIDPOINT) {
+ return lower_value / 2 + higher_value / 2;
+ } else {
+ DCHECK(false);
+ return NAN;
+ }
+ }
+};
+
+// histogram or sort approach per value range and size, only for integers
+template <typename InType>
+struct CountOrSortQuantiler {
+ using CType = typename InType::c_type;
+
+ void Exec(KernelContext* ctx, const ExecBatch& batch, Datum* out) {
+ // cross point to benefit from histogram approach
+ // parameters estimated from ad-hoc benchmarks manually
+ static constexpr int kMinArraySize = 65536 * sizeof(int) / sizeof(CType);
+ static constexpr int kMaxValueRange = 65536;
+
+ const Datum& datum = batch[0];
+ if (datum.length() - datum.null_count() >= kMinArraySize) {
Review comment:
Yes. It still seems worthwhile to find the min/max, since the input values
may not cover the full 64k range.
----------------------------------------------------------------
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.
For queries about this service, please contact Infrastructure at:
[email protected]