This is an automated email from the ASF dual-hosted git repository.
apitrou pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/arrow.git
The following commit(s) were added to refs/heads/main by this push:
new 04249b9137 GH-45216: [C++][Compute] Refactor Rank implementation
(#45217)
04249b9137 is described below
commit 04249b9137fe3943698f8d1f8ab513b125ea1d91
Author: Antoine Pitrou <[email protected]>
AuthorDate: Mon Jan 13 14:39:43 2025 +0100
GH-45216: [C++][Compute] Refactor Rank implementation (#45217)
### Rationale for this change
The Rank implementation currently mixes tie/duplicate detection and rank
computation in a single function, `CreateRankings`. This makes it hard to
reuse for other Rank-like functions, such as the Percentile Rank function
proposed in GH-45190.
### What changes are included in this PR?
Split duplicate detection into a dedicated function that sets a marker bit
in the sort-indices array (that array is private to the Rank implementation,
so mutating it is safe).
The rank computation itself (`CreateRankings`) becomes simpler and, moreover,
it no longer needs to read the input values: it therefore becomes
type-agnostic.
This yields a code size reduction (around 45kB saved on the author's
machine):
* before:
```console
$ size /build/build-release/relwithdebinfo/libarrow.so
    text    data     bss      dec     hex filename
26072218  353832 2567985 28994035 1ba69f3 /build/build-release/relwithdebinfo/libarrow.so
```
* after:
```console
$ size /build/build-release/relwithdebinfo/libarrow.so
    text    data     bss      dec     hex filename
26028198  353832 2567985 28950015 1b9bdff /build/build-release/relwithdebinfo/libarrow.so
```
Rank benchmark results are mostly neutral, though there are slight
improvements on some benchmarks and slight regressions on others, especially
on all-null input.
### Are these changes tested?
Yes, by existing tests.
### Are there any user-facing changes?
No.
* GitHub Issue: #45216
Authored-by: Antoine Pitrou <[email protected]>
Signed-off-by: Antoine Pitrou <[email protected]>
---
cpp/src/arrow/compute/kernels/vector_rank.cc | 157 +++++++++---------
cpp/src/arrow/compute/kernels/vector_sort.cc | 4 +-
.../arrow/compute/kernels/vector_sort_internal.h | 175 ++++++++-------------
cpp/src/arrow/type_traits.h | 7 +
4 files changed, 145 insertions(+), 198 deletions(-)
diff --git a/cpp/src/arrow/compute/kernels/vector_rank.cc
b/cpp/src/arrow/compute/kernels/vector_rank.cc
index b374862fe6..4fdc83788c 100644
--- a/cpp/src/arrow/compute/kernels/vector_rank.cc
+++ b/cpp/src/arrow/compute/kernels/vector_rank.cc
@@ -28,45 +28,61 @@ namespace {
// ----------------------------------------------------------------------
// Rank implementation
-template <typename ValueSelector,
- typename T = std::decay_t<std::invoke_result_t<ValueSelector,
int64_t>>>
+// A bit that is set in the sort indices when the value at the current sort
index
+// is the same as the value at the previous sort index.
+constexpr uint64_t kDuplicateMask = 1ULL << 63;
+
+constexpr bool NeedsDuplicates(RankOptions::Tiebreaker tiebreaker) {
+ return tiebreaker != RankOptions::First;
+}
+
+template <typename ValueSelector>
+void MarkDuplicates(const NullPartitionResult& sorted, ValueSelector&&
value_selector) {
+ using T = decltype(value_selector(int64_t{}));
+
+ // Process non-nulls
+ if (sorted.non_nulls_end != sorted.non_nulls_begin) {
+ auto it = sorted.non_nulls_begin;
+ T prev_value = value_selector(*it);
+ while (++it < sorted.non_nulls_end) {
+ T curr_value = value_selector(*it);
+ if (curr_value == prev_value) {
+ *it |= kDuplicateMask;
+ }
+ prev_value = curr_value;
+ }
+ }
+
+ // Process nulls
+ if (sorted.nulls_end != sorted.nulls_begin) {
+ // TODO this should be able to distinguish between NaNs and real nulls
(GH-45193)
+ auto it = sorted.nulls_begin;
+ while (++it < sorted.nulls_end) {
+ *it |= kDuplicateMask;
+ }
+ }
+}
+
Result<Datum> CreateRankings(ExecContext* ctx, const NullPartitionResult&
sorted,
const NullPlacement null_placement,
- const RankOptions::Tiebreaker tiebreaker,
- ValueSelector&& value_selector) {
+ const RankOptions::Tiebreaker tiebreaker) {
auto length = sorted.overall_end() - sorted.overall_begin();
ARROW_ASSIGN_OR_RAISE(auto rankings,
MakeMutableUInt64Array(length, ctx->memory_pool()));
auto out_begin = rankings->GetMutableValues<uint64_t>(1);
uint64_t rank;
+ auto is_duplicate = [](uint64_t index) { return (index & kDuplicateMask) !=
0; };
+ auto original_index = [](uint64_t index) { return index & ~kDuplicateMask; };
+
switch (tiebreaker) {
case RankOptions::Dense: {
- T curr_value, prev_value{};
rank = 0;
-
- if (null_placement == NullPlacement::AtStart && sorted.null_count() > 0)
{
- rank++;
- for (auto it = sorted.nulls_begin; it < sorted.nulls_end; it++) {
- out_begin[*it] = rank;
- }
- }
-
- for (auto it = sorted.non_nulls_begin; it < sorted.non_nulls_end; it++) {
- curr_value = value_selector(*it);
- if (it == sorted.non_nulls_begin || curr_value != prev_value) {
- rank++;
- }
-
- out_begin[*it] = rank;
- prev_value = curr_value;
- }
-
- if (null_placement == NullPlacement::AtEnd) {
- rank++;
- for (auto it = sorted.nulls_begin; it < sorted.nulls_end; it++) {
- out_begin[*it] = rank;
+ for (auto it = sorted.overall_begin(); it < sorted.overall_end(); ++it) {
+ if (!is_duplicate(*it)) {
+ ++rank;
}
+ out_begin[original_index(*it)] = rank;
}
break;
}
@@ -74,68 +90,35 @@ Result<Datum> CreateRankings(ExecContext* ctx, const
NullPartitionResult& sorted
case RankOptions::First: {
rank = 0;
for (auto it = sorted.overall_begin(); it < sorted.overall_end(); it++) {
+ // No duplicate marks expected for RankOptions::First
+ DCHECK(!is_duplicate(*it));
out_begin[*it] = ++rank;
}
break;
}
case RankOptions::Min: {
- T curr_value, prev_value{};
rank = 0;
-
- if (null_placement == NullPlacement::AtStart) {
- rank++;
- for (auto it = sorted.nulls_begin; it < sorted.nulls_end; it++) {
- out_begin[*it] = rank;
- }
- }
-
- for (auto it = sorted.non_nulls_begin; it < sorted.non_nulls_end; it++) {
- curr_value = value_selector(*it);
- if (it == sorted.non_nulls_begin || curr_value != prev_value) {
+ for (auto it = sorted.overall_begin(); it < sorted.overall_end(); ++it) {
+ if (!is_duplicate(*it)) {
rank = (it - sorted.overall_begin()) + 1;
}
- out_begin[*it] = rank;
- prev_value = curr_value;
- }
-
- if (null_placement == NullPlacement::AtEnd) {
- rank = sorted.non_null_count() + 1;
- for (auto it = sorted.nulls_begin; it < sorted.nulls_end; it++) {
- out_begin[*it] = rank;
- }
+ out_begin[original_index(*it)] = rank;
}
break;
}
case RankOptions::Max: {
- // The algorithm for Max is just like Min, but in reverse order.
- T curr_value, prev_value{};
rank = length;
-
- if (null_placement == NullPlacement::AtEnd) {
- for (auto it = sorted.nulls_begin; it < sorted.nulls_end; it++) {
- out_begin[*it] = rank;
- }
- }
-
- for (auto it = sorted.non_nulls_end - 1; it >= sorted.non_nulls_begin;
it--) {
- curr_value = value_selector(*it);
-
- if (it == sorted.non_nulls_end - 1 || curr_value != prev_value) {
- rank = (it - sorted.overall_begin()) + 1;
- }
- out_begin[*it] = rank;
- prev_value = curr_value;
- }
-
- if (null_placement == NullPlacement::AtStart) {
- rank = sorted.null_count();
- for (auto it = sorted.nulls_begin; it < sorted.nulls_end; it++) {
- out_begin[*it] = rank;
+ for (auto it = sorted.overall_end() - 1; it >= sorted.overall_begin();
--it) {
+ out_begin[original_index(*it)] = rank;
+ // If the current index isn't marked as duplicate, then it's the last
+ // tie in a row (since we iterate in reverse order), so update rank
+ // for the next row of ties.
+ if (!is_duplicate(*it)) {
+ rank = it - sorted.overall_begin();
}
}
-
break;
}
}
@@ -209,11 +192,14 @@ class Ranker<Array> : public RankerMixin<Array,
Ranker<Array>> {
array_sorter(indices_begin_, indices_end_, array, 0,
ArraySortOptions(order_,
null_placement_), ctx_));
- auto value_selector = [&array](int64_t index) {
- return GetView::LogicalValue(array.GetView(index));
- };
- ARROW_ASSIGN_OR_RAISE(*output_, CreateRankings(ctx_, sorted,
null_placement_,
- tiebreaker_,
value_selector));
+ if (NeedsDuplicates(tiebreaker_)) {
+ auto value_selector = [&array](int64_t index) {
+ return GetView::LogicalValue(array.GetView(index));
+ };
+ MarkDuplicates(sorted, value_selector);
+ }
+ ARROW_ASSIGN_OR_RAISE(*output_,
+ CreateRankings(ctx_, sorted, null_placement_,
tiebreaker_));
return Status::OK();
}
@@ -238,13 +224,16 @@ class Ranker<ChunkedArray> : public
RankerMixin<ChunkedArray, Ranker<ChunkedArra
SortChunkedArray(ctx_, indices_begin_, indices_end_, physical_type_,
physical_chunks_, order_, null_placement_));
- const auto arrays = GetArrayPointers(physical_chunks_);
- auto value_selector = [resolver =
ChunkedArrayResolver(span(arrays))](int64_t index) {
- return resolver.Resolve(index).Value<InType>();
- };
- ARROW_ASSIGN_OR_RAISE(*output_, CreateRankings(ctx_, sorted,
null_placement_,
- tiebreaker_,
value_selector));
-
+ if (NeedsDuplicates(tiebreaker_)) {
+ const auto arrays = GetArrayPointers(physical_chunks_);
+ auto value_selector = [resolver =
+ ChunkedArrayResolver(span(arrays))](int64_t
index) {
+ return resolver.Resolve(index).Value<InType>();
+ };
+ MarkDuplicates(sorted, value_selector);
+ }
+ ARROW_ASSIGN_OR_RAISE(*output_,
+ CreateRankings(ctx_, sorted, null_placement_,
tiebreaker_));
return Status::OK();
}
diff --git a/cpp/src/arrow/compute/kernels/vector_sort.cc
b/cpp/src/arrow/compute/kernels/vector_sort.cc
index d81187837d..f9ae69730f 100644
--- a/cpp/src/arrow/compute/kernels/vector_sort.cc
+++ b/cpp/src/arrow/compute/kernels/vector_sort.cc
@@ -121,7 +121,7 @@ class ChunkedArraySorter : public TypeVisitor {
CompressedChunkLocation* nulls_middle,
CompressedChunkLocation* nulls_end,
CompressedChunkLocation* temp_indices, int64_t
null_count) {
- if (has_null_like_values<typename ArrayType::TypeClass>::value) {
+ if (has_null_like_values<typename ArrayType::TypeClass>()) {
PartitionNullsOnly<StablePartitioner>(nulls_begin, nulls_end, arrays,
null_count, null_placement_);
}
@@ -781,7 +781,7 @@ class TableSorter {
CompressedChunkLocation* nulls_middle,
CompressedChunkLocation* nulls_end,
CompressedChunkLocation* temp_indices, int64_t null_count) {
- if constexpr (has_null_like_values<ArrowType>::value) {
+ if constexpr (has_null_like_values<ArrowType>()) {
// Merge rows with a null or a null-like in the first sort key
auto& comparator = comparator_;
const auto& first_sort_key = sort_keys_[0];
diff --git a/cpp/src/arrow/compute/kernels/vector_sort_internal.h
b/cpp/src/arrow/compute/kernels/vector_sort_internal.h
index 97a2db1d11..cc6b7834a3 100644
--- a/cpp/src/arrow/compute/kernels/vector_sort_internal.h
+++ b/cpp/src/arrow/compute/kernels/vector_sort_internal.h
@@ -29,9 +29,7 @@
#include "arrow/type.h"
#include "arrow/type_traits.h"
-namespace arrow {
-namespace compute {
-namespace internal {
+namespace arrow::compute::internal {
// Visit all physical types for which sorting is implemented.
#define VISIT_SORTABLE_PHYSICAL_TYPES(VISIT) \
@@ -71,49 +69,17 @@ struct StablePartitioner {
}
};
-template <typename TypeClass, typename Enable = void>
-struct NullTraits {
- using has_null_like_values = std::false_type;
-};
-
template <typename TypeClass>
-struct NullTraits<TypeClass, enable_if_physical_floating_point<TypeClass>> {
- using has_null_like_values = std::true_type;
-};
-
-template <typename TypeClass>
-using has_null_like_values = typename
NullTraits<TypeClass>::has_null_like_values;
+constexpr bool has_null_like_values() {
+ return is_physical_floating(TypeClass::type_id);
+}
// Compare two values, taking NaNs into account
-template <typename Type, typename Enable = void>
-struct ValueComparator;
-
-template <typename Type>
-struct ValueComparator<Type, enable_if_t<!has_null_like_values<Type>::value>> {
- template <typename Value>
- static int Compare(const Value& left, const Value& right, SortOrder order,
- NullPlacement null_placement) {
- int compared;
- if (left == right) {
- compared = 0;
- } else if (left > right) {
- compared = 1;
- } else {
- compared = -1;
- }
- if (order == SortOrder::Descending) {
- compared = -compared;
- }
- return compared;
- }
-};
-
-template <typename Type>
-struct ValueComparator<Type, enable_if_t<has_null_like_values<Type>::value>> {
- template <typename Value>
- static int Compare(const Value& left, const Value& right, SortOrder order,
- NullPlacement null_placement) {
+template <typename Type, typename Value>
+int CompareTypeValues(Value&& left, Value&& right, SortOrder order,
+ NullPlacement null_placement) {
+ if constexpr (has_null_like_values<Type>()) {
const bool is_nan_left = std::isnan(left);
const bool is_nan_right = std::isnan(right);
if (is_nan_left && is_nan_right) {
@@ -123,25 +89,19 @@ struct ValueComparator<Type,
enable_if_t<has_null_like_values<Type>::value>> {
} else if (is_nan_right) {
return null_placement == NullPlacement::AtStart ? 1 : -1;
}
- int compared;
- if (left == right) {
- compared = 0;
- } else if (left > right) {
- compared = 1;
- } else {
- compared = -1;
- }
- if (order == SortOrder::Descending) {
- compared = -compared;
- }
- return compared;
}
-};
-
-template <typename Type, typename Value>
-int CompareTypeValues(const Value& left, const Value& right, SortOrder order,
- NullPlacement null_placement) {
- return ValueComparator<Type>::Compare(left, right, order, null_placement);
+ int compared;
+ if (left == right) {
+ compared = 0;
+ } else if (left > right) {
+ compared = 1;
+ } else {
+ compared = -1;
+ }
+ if (order == SortOrder::Descending) {
+ compared = -compared;
+ }
+ return compared;
}
template <typename IndexType>
@@ -238,33 +198,28 @@ NullPartitionResult PartitionNullsOnly(uint64_t*
indices_begin, uint64_t* indice
//
// `offset` is used when this is called on a chunk of a chunked array
template <typename ArrayType, typename Partitioner>
-enable_if_t<!has_null_like_values<typename ArrayType::TypeClass>::value,
- NullPartitionResult>
-PartitionNullLikes(uint64_t* indices_begin, uint64_t* indices_end,
- const ArrayType& values, int64_t offset,
- NullPlacement null_placement) {
- return NullPartitionResult::NoNulls(indices_begin, indices_end,
null_placement);
-}
-
-template <typename ArrayType, typename Partitioner>
-enable_if_t<has_null_like_values<typename ArrayType::TypeClass>::value,
- NullPartitionResult>
-PartitionNullLikes(uint64_t* indices_begin, uint64_t* indices_end,
- const ArrayType& values, int64_t offset,
- NullPlacement null_placement) {
- Partitioner partitioner;
- if (null_placement == NullPlacement::AtStart) {
- auto null_likes_end =
- partitioner(indices_begin, indices_end, [&values, &offset](uint64_t
ind) {
- return std::isnan(values.GetView(ind - offset));
- });
- return NullPartitionResult::NullsAtStart(indices_begin, indices_end,
null_likes_end);
+NullPartitionResult PartitionNullLikes(uint64_t* indices_begin, uint64_t*
indices_end,
+ const ArrayType& values, int64_t offset,
+ NullPlacement null_placement) {
+ if constexpr (has_null_like_values<typename ArrayType::TypeClass>()) {
+ Partitioner partitioner;
+ if (null_placement == NullPlacement::AtStart) {
+ auto null_likes_end =
+ partitioner(indices_begin, indices_end, [&values, &offset](uint64_t
ind) {
+ return std::isnan(values.GetView(ind - offset));
+ });
+ return NullPartitionResult::NullsAtStart(indices_begin, indices_end,
+ null_likes_end);
+ } else {
+ auto null_likes_begin =
+ partitioner(indices_begin, indices_end, [&values, &offset](uint64_t
ind) {
+ return !std::isnan(values.GetView(ind - offset));
+ });
+ return NullPartitionResult::NullsAtEnd(indices_begin, indices_end,
+ null_likes_begin);
+ }
} else {
- auto null_likes_begin =
- partitioner(indices_begin, indices_end, [&values, &offset](uint64_t
ind) {
- return !std::isnan(values.GetView(ind - offset));
- });
- return NullPartitionResult::NullsAtEnd(indices_begin, indices_end,
null_likes_begin);
+ return NullPartitionResult::NoNulls(indices_begin, indices_end,
null_placement);
}
}
@@ -344,32 +299,30 @@ ChunkedNullPartitionResult
PartitionNullsOnly(CompressedChunkLocation* locations
}
}
-template <typename ArrayType, typename Partitioner>
-enable_if_t<!has_null_like_values<typename ArrayType::TypeClass>::value,
- NullPartitionResult>
-PartitionNullLikes(uint64_t* indices_begin, uint64_t* indices_end,
- const ChunkedArrayResolver& resolver, NullPlacement
null_placement) {
- return NullPartitionResult::NoNulls(indices_begin, indices_end,
null_placement);
-}
-
template <typename ArrayType, typename Partitioner,
typename TypeClass = typename ArrayType::TypeClass>
-enable_if_t<has_null_like_values<TypeClass>::value, NullPartitionResult>
-PartitionNullLikes(uint64_t* indices_begin, uint64_t* indices_end,
- const ChunkedArrayResolver& resolver, NullPlacement
null_placement) {
- Partitioner partitioner;
- if (null_placement == NullPlacement::AtStart) {
- auto null_likes_end = partitioner(indices_begin, indices_end, [&](uint64_t
ind) {
- const auto chunk = resolver.Resolve(ind);
- return std::isnan(chunk.Value<TypeClass>());
- });
- return NullPartitionResult::NullsAtStart(indices_begin, indices_end,
null_likes_end);
+NullPartitionResult PartitionNullLikes(uint64_t* indices_begin, uint64_t*
indices_end,
+ const ChunkedArrayResolver& resolver,
+ NullPlacement null_placement) {
+ if constexpr (has_null_like_values<typename ArrayType::TypeClass>()) {
+ Partitioner partitioner;
+ if (null_placement == NullPlacement::AtStart) {
+ auto null_likes_end = partitioner(indices_begin, indices_end,
[&](uint64_t ind) {
+ const auto chunk = resolver.Resolve(ind);
+ return std::isnan(chunk.Value<TypeClass>());
+ });
+ return NullPartitionResult::NullsAtStart(indices_begin, indices_end,
+ null_likes_end);
+ } else {
+ auto null_likes_begin = partitioner(indices_begin, indices_end,
[&](uint64_t ind) {
+ const auto chunk = resolver.Resolve(ind);
+ return !std::isnan(chunk.Value<TypeClass>());
+ });
+ return NullPartitionResult::NullsAtEnd(indices_begin, indices_end,
+ null_likes_begin);
+ }
} else {
- auto null_likes_begin = partitioner(indices_begin, indices_end,
[&](uint64_t ind) {
- const auto chunk = resolver.Resolve(ind);
- return !std::isnan(chunk.Value<TypeClass>());
- });
- return NullPartitionResult::NullsAtEnd(indices_begin, indices_end,
null_likes_begin);
+ return NullPartitionResult::NoNulls(indices_begin, indices_end,
null_placement);
}
}
@@ -853,6 +806,4 @@ inline Result<std::shared_ptr<ArrayData>>
MakeMutableUInt64Array(
return ArrayData::Make(uint64(), length, {nullptr, std::move(data)},
/*null_count=*/0);
}
-} // namespace internal
-} // namespace compute
-} // namespace arrow
+} // namespace arrow::compute::internal
diff --git a/cpp/src/arrow/type_traits.h b/cpp/src/arrow/type_traits.h
index 92009c8560..6ed495dcb2 100644
--- a/cpp/src/arrow/type_traits.h
+++ b/cpp/src/arrow/type_traits.h
@@ -1069,6 +1069,13 @@ constexpr bool is_floating(Type::type type_id) {
return false;
}
+/// \brief Check for a physical floating point type
+///
+/// This predicate matches floating-point types, except half-float.
+constexpr bool is_physical_floating(Type::type type_id) {
+ return is_floating(type_id) && type_id != Type::HALF_FLOAT;
+}
+
/// \brief Check for a numeric type
///
/// This predicate doesn't match decimals (see `is_decimal`).