kou commented on a change in pull request #8612:
URL: https://github.com/apache/arrow/pull/8612#discussion_r526617675
##########
File path: cpp/src/arrow/compute/kernels/vector_sort.cc
##########
@@ -346,14 +374,396 @@ void AddSortingKernels(VectorKernel base,
VectorFunction* func) {
}
}
+class TableSorter : public TypeVisitor {
+ private:
+ struct ResolvedSortKey {
+ ResolvedSortKey(const ChunkedArray& chunked_array, const SortOrder order)
+ : order(order) {
+ type = chunked_array.type().get();
+ null_count = chunked_array.null_count();
+ num_chunks = chunked_array.num_chunks();
+ for (const auto& chunk : chunked_array.chunks()) {
+ chunks.push_back(chunk.get());
+ }
+ }
+
+ template <typename ArrayType>
+ ArrayType* ResolveChunk(int64_t index, int64_t& chunk_index) const {
+ if (num_chunks == 1) {
+ chunk_index = index;
+ return static_cast<ArrayType*>(chunks[0]);
+ } else {
+ int64_t offset = 0;
+ for (size_t i = 0; i < num_chunks; ++i) {
+ if (index < offset + chunks[i]->length()) {
+ chunk_index = index - offset;
+ return static_cast<ArrayType*>(chunks[i]);
+ }
+ offset += chunks[i]->length();
+ }
+ return nullptr;
+ }
+ }
+
+ SortOrder order;
+ DataType* type;
+ int64_t null_count;
+ size_t num_chunks;
+ std::vector<Array*> chunks;
+ };
+
+ class Comparer : public TypeVisitor {
+ public:
+ Comparer(const Table& table, const std::vector<SortKey>& sort_keys)
+ : TypeVisitor(), status_(Status::OK()) {
+ for (const auto& sort_key : sort_keys) {
+ const auto& chunked_array = table.GetColumnByName(sort_key.name);
+ if (!chunked_array) {
+ status_ = Status::Invalid("Nonexistent sort key column: ",
sort_key.name);
+ return;
+ }
+ sort_keys_.emplace_back(*chunked_array, sort_key.order);
+ }
+ }
+
+ Status status() { return status_; }
+
+ const std::vector<ResolvedSortKey>& sort_keys() { return sort_keys_; }
+
+ bool Compare(uint64_t left, uint64_t right, size_t start_sort_key_index) {
+ current_left_ = left;
+ current_right_ = right;
+ current_compared_ = 0;
+ auto num_sort_keys = sort_keys_.size();
+ for (size_t i = start_sort_key_index; i < num_sort_keys; ++i) {
+ current_sort_key_index_ = i;
+ status_ = sort_keys_[i].type->Accept(this);
+ if (current_compared_ != 0) {
+ break;
+ }
+ }
+ return current_compared_ < 0;
+ }
+
+#define VISIT(TYPE) \
+ Status Visit(const TYPE##Type& type) override { \
+ current_compared_ = CompareType<TYPE##Type>(); \
+ return Status::OK(); \
+ }
+
+ VISIT(Int8)
+ VISIT(Int16)
+ VISIT(Int32)
+ VISIT(Int64)
+ VISIT(UInt8)
+ VISIT(UInt16)
+ VISIT(UInt32)
+ VISIT(UInt64)
+ VISIT(Float)
+ VISIT(Double)
+ VISIT(String)
+ VISIT(Binary)
+ VISIT(LargeString)
+ VISIT(LargeBinary)
+
+#undef VISIT
+
+ private:
+ template <typename Type>
+ int32_t CompareType() {
+ using ArrayType = typename TypeTraits<Type>::ArrayType;
+ const auto& sort_key = sort_keys_[current_sort_key_index_];
+ auto order = sort_key.order;
+ int64_t index_left = 0;
+ auto array_left = sort_key.ResolveChunk<ArrayType>(current_left_,
index_left);
+ int64_t index_right = 0;
+ auto array_right = sort_key.ResolveChunk<ArrayType>(current_right_,
index_right);
+ if (sort_key.null_count > 0) {
+ auto is_null_left = array_left->IsNull(index_left);
+ auto is_null_right = array_right->IsNull(index_right);
+ if (is_null_left && is_null_right) {
+ return 0;
+ } else if (is_null_left) {
+ return 1;
+ } else if (is_null_right) {
+ return -1;
+ }
+ }
+ auto left = array_left->GetView(index_left);
+ auto right = array_right->GetView(index_right);
+ int32_t compared;
+ if (left == right) {
+ compared = 0;
+ } else if (left > right) {
+ compared = 1;
+ } else {
+ compared = -1;
+ }
+ if (order == SortOrder::DESCENDING) {
+ compared = -compared;
+ }
+ return compared;
+ }
+
+ Status status_;
+ std::vector<ResolvedSortKey> sort_keys_;
+ int64_t current_left_;
+ int64_t current_right_;
+ size_t current_sort_key_index_;
+ int32_t current_compared_;
+ };
+
+ public:
+ TableSorter(uint64_t* indices_begin, uint64_t* indices_end, const Table&
table,
+ const SortOptions& options)
+ : indices_begin_(indices_begin),
+ indices_end_(indices_end),
+ comparer_(table, options.sort_keys) {}
+
+ Status Sort() {
+ ARROW_RETURN_NOT_OK(comparer_.status());
+ return comparer_.sort_keys()[0].type->Accept(this);
+ }
+
+#define VISIT(TYPE) \
+ Status Visit(const TYPE##Type& type) override { return
SortInternal<TYPE##Type>(); }
+
+ VISIT(Int8)
+ VISIT(Int16)
+ VISIT(Int32)
+ VISIT(Int64)
+ VISIT(UInt8)
+ VISIT(UInt16)
+ VISIT(UInt32)
+ VISIT(UInt64)
+ VISIT(Float)
+ VISIT(Double)
+ VISIT(String)
+ VISIT(Binary)
+ VISIT(LargeString)
+ VISIT(LargeBinary)
+
+#undef VISIT
+
+ private:
+ template <typename Type>
+ Status SortInternal() {
+ using ArrayType = typename TypeTraits<Type>::ArrayType;
+ std::iota(indices_begin_, indices_end_, 0);
+
+ auto& comparer = comparer_;
+ const auto& first_sort_key = comparer.sort_keys()[0];
+ auto nulls_begin = indices_end_;
+ nulls_begin = PartitionNullsInternal<Type>(first_sort_key);
+ std::stable_sort(
+ indices_begin_, nulls_begin,
+ [&first_sort_key, &comparer](uint64_t left, uint64_t right) {
+ int64_t index_left = 0;
+ auto array_left = first_sort_key.ResolveChunk<ArrayType>(left,
index_left);
+ int64_t index_right = 0;
+ auto array_right = first_sort_key.ResolveChunk<ArrayType>(right,
index_right);
+ auto value_left = array_left->GetView(index_left);
+ auto value_right = array_right->GetView(index_right);
+ if (value_left == value_right) {
+ return comparer.Compare(left, right, 1);
+ } else {
+ auto compared = value_left < value_right;
+ if (first_sort_key.order == SortOrder::ASCENDING) {
+ return compared;
+ } else {
+ return !compared;
+ }
+ }
+ });
+ return Status::OK();
+ }
+
+ template <typename Type>
+ enable_if_t<!is_floating_type<Type>::value, uint64_t*>
PartitionNullsInternal(
+ const ResolvedSortKey& first_sort_key) {
+ using ArrayType = typename TypeTraits<Type>::ArrayType;
+ if (first_sort_key.null_count == 0) {
+ return indices_end_;
+ }
+ StablePartitioner partitioner;
+ auto nulls_begin =
+ partitioner(indices_begin_, indices_end_, [&first_sort_key](uint64_t
index) {
+ int64_t index_chunk = 0;
+ auto chunk = first_sort_key.ResolveChunk<ArrayType>(index,
index_chunk);
+ return !chunk->IsNull(index_chunk);
+ });
+ auto& comparer = comparer_;
+ std::stable_sort(nulls_begin, indices_end_,
+ [&comparer](uint64_t left, uint64_t right) {
+ return comparer.Compare(left, right, 1);
+ });
+ return nulls_begin;
+ }
+
+ template <typename Type>
+ enable_if_t<is_floating_type<Type>::value, uint64_t*> PartitionNullsInternal(
+ const ResolvedSortKey& first_sort_key) {
+ using ArrayType = typename TypeTraits<Type>::ArrayType;
+ StablePartitioner partitioner;
+ if (first_sort_key.null_count == 0) {
+ return partitioner(indices_begin_, indices_end_,
[&first_sort_key](uint64_t index) {
+ int64_t index_chunk = 0;
+ auto chunk = first_sort_key.ResolveChunk<ArrayType>(index,
index_chunk);
+ return !std::isnan(chunk->GetView(index_chunk));
+ });
+ }
+ auto nans_and_nulls_begin =
+ partitioner(indices_begin_, indices_end_, [&first_sort_key](uint64_t
index) {
+ int64_t index_chunk = 0;
+ auto chunk = first_sort_key.ResolveChunk<ArrayType>(index,
index_chunk);
+ return !chunk->IsNull(index_chunk) &&
!std::isnan(chunk->GetView(index_chunk));
+ });
+ auto nulls_begin = nans_and_nulls_begin;
+ if (first_sort_key.null_count < static_cast<int64_t>(indices_end_ -
nulls_begin)) {
+ // move Nulls after NaN
+ nulls_begin = partitioner(
+ nans_and_nulls_begin, indices_end_, [&first_sort_key](uint64_t
index) {
+ int64_t index_chunk = 0;
+ auto chunk = first_sort_key.ResolveChunk<ArrayType>(index,
index_chunk);
+ return !chunk->IsNull(index_chunk);
+ });
+ }
+ auto& comparer = comparer_;
+ if (nans_and_nulls_begin != nulls_begin) {
+ std::stable_sort(nans_and_nulls_begin, nulls_begin,
+ [&comparer](uint64_t left, uint64_t right) {
+ return comparer.Compare(left, right, 1);
+ });
+ }
+ std::stable_sort(nulls_begin, indices_end_,
+ [&comparer](uint64_t left, uint64_t right) {
+ return comparer.Compare(left, right, 1);
+ });
+ return nans_and_nulls_begin;
+ }
+
+ uint64_t* indices_begin_;
+ uint64_t* indices_end_;
+ Comparer comparer_;
+};
+
const FunctionDoc sort_indices_doc(
+ "Return the indices that would sort an array, record batch or table",
+ ("This function computes an array of indices that define a stable sort\n"
+ "of the input array, record batch or table. Null values are considered\n"
+ "greater than any other value and are therefore sorted at the end of
the\n"
+ "input. For floating-point types, NaNs are considered greater than any\n"
+ "other non-null value, but smaller than null values."),
+ {"input"}, "SortOptions");
+
+class SortIndicesMetaFunction : public MetaFunction {
+ public:
+ SortIndicesMetaFunction()
+ : MetaFunction("sort_indices", Arity::Unary(), &sort_indices_doc) {}
+
+ Result<Datum> ExecuteImpl(const std::vector<Datum>& args,
+ const FunctionOptions* options,
+ ExecContext* ctx) const override {
+ const SortOptions& sort_options = static_cast<const
SortOptions&>(*options);
+ switch (args[0].kind()) {
+ case Datum::ARRAY:
+ return SortIndices(*args[0].make_array(), sort_options, ctx);
+ break;
+ case Datum::CHUNKED_ARRAY:
+ return SortIndices(*args[0].chunked_array(), sort_options, ctx);
+ break;
+ case Datum::RECORD_BATCH: {
+ ARROW_ASSIGN_OR_RAISE(auto table,
+
Table::FromRecordBatches({args[0].record_batch()}));
+ return SortIndices(*table, sort_options, ctx);
+ } break;
+ case Datum::TABLE:
+ return SortIndices(*args[0].table(), sort_options, ctx);
+ break;
+ default:
+ break;
+ }
+ return Status::NotImplemented(
+ "Unsupported types for sort_indices operation: "
+ "values=",
+ args[0].ToString());
+ }
+
+ private:
+ Result<std::shared_ptr<Array>> SortIndices(const Array& values,
+ const SortOptions& options,
+ ExecContext* ctx) const {
+ SortOrder order = SortOrder::ASCENDING;
+ if (!options.sort_keys.empty()) {
+ order = options.sort_keys[0].order;
+ }
+ ArraySortOptions array_options(order);
+ ARROW_ASSIGN_OR_RAISE(
+ Datum result, CallFunction("array_sort_indices", {values},
&array_options, ctx));
+ return result.make_array();
+ }
+
+ Result<std::shared_ptr<Array>> SortIndices(const ChunkedArray& values,
+ const SortOptions& options,
+ ExecContext* ctx) const {
+ SortOrder order = SortOrder::ASCENDING;
+ if (!options.sort_keys.empty()) {
+ order = options.sort_keys[0].order;
+ }
+ ArraySortOptions array_options(order);
+
+ std::shared_ptr<Array> array_values;
+ if (values.num_chunks() == 1) {
+ array_values = values.chunk(0);
+ } else {
+ ARROW_ASSIGN_OR_RAISE(array_values,
+ Concatenate(values.chunks(), ctx->memory_pool()));
Review comment:
Ah, I forgot to mention this.
I wanted to handle the ChunkedArray case as a follow-up task and focus on
Table in this pull request.
But I've implemented naive ChunkedArray support that doesn't use
`Concatenate()` in this pull request. It sorts each chunk and then merges them,
but it doesn't use threads.
----------------------------------------------------------------
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.
For queries about this service, please contact Infrastructure at:
users@infra.apache.org