kou commented on a change in pull request #8612:
URL: https://github.com/apache/arrow/pull/8612#discussion_r526619466
##########
File path: cpp/src/arrow/compute/kernels/vector_sort_benchmark.cc
##########
@@ -56,17 +57,99 @@ static void SortToIndicesInt64Compare(benchmark::State& state) {
auto max = std::numeric_limits<int64_t>::max();
auto values = rand.Int64(array_size, min, max, args.null_proportion);
- SortToIndicesBenchmark(state, values);
+ ArraySortIndicesBenchmark(state, values);
}
-BENCHMARK(SortToIndicesInt64Count)
+// Shared benchmark driver: runs SortIndices() on `table` with `options` once per
+// benchmark iteration (aborting on any error) and reports throughput as the
+// number of table rows processed per iteration.
+static void TableSortIndicesBenchmark(benchmark::State& state,
+                                      const std::shared_ptr<Table>& table,
+                                      const SortOptions& options) {
+  const Table& input = *table;
+  for (auto _ : state) {
+    auto sorted = SortIndices(input, options);
+    ABORT_NOT_OK(sorted.status());
+  }
+  const int64_t rows = input.num_rows();
+  state.SetItemsProcessed(state.iterations() * rows);
+}
+
+// Table sort over a single int64 column whose values are drawn from the narrow
+// range [-100, 100], so the column holds many duplicate values.
+static void TableSortIndicesInt64Count(benchmark::State& state) {
+  RegressionArgs args(state);
+
+  const int64_t num_values = args.size / sizeof(int64_t);
+  auto generator = random::RandomArrayGenerator(kSeed);
+  auto column = generator.Int64(num_values, -100, 100, args.null_proportion);
+  auto table_schema = schema({field("int64", int64())});
+  auto table = Table::Make(table_schema, {column}, num_values);
+  const SortOptions options({SortKey("int64", SortOrder::ASCENDING)});
+
+  TableSortIndicesBenchmark(state, table, options);
+}
+
+// Table sort over a single int64 column whose values are drawn from the full
+// int64 range, so duplicates are presumably rare.
+static void TableSortIndicesInt64Compare(benchmark::State& state) {
+  RegressionArgs args(state);
+
+  const int64_t num_values = args.size / sizeof(int64_t);
+  auto generator = random::RandomArrayGenerator(kSeed);
+
+  const auto lowest = std::numeric_limits<int64_t>::min();
+  const auto highest = std::numeric_limits<int64_t>::max();
+  auto column = generator.Int64(num_values, lowest, highest, args.null_proportion);
+  auto table = Table::Make(schema({field("int64", int64())}), {column}, num_values);
+  const SortOptions options({SortKey("int64", SortOrder::ASCENDING)});
+
+  TableSortIndicesBenchmark(state, table, options);
+}
+
+static void TableSortIndicesInt64Int64(benchmark::State& state) {
+ RegressionArgs args(state);
+
+ const int64_t array_size = args.size / sizeof(int64_t);
+ auto rand = random::RandomArrayGenerator(kSeed);
+
+ auto min = std::numeric_limits<int64_t>::min();
+ auto max = std::numeric_limits<int64_t>::max();
Review comment:
Yes.
I've added narrow range cases that have many duplicate values.
##########
File path: cpp/src/arrow/compute/kernels/vector_sort.cc
##########
@@ -346,14 +374,396 @@ void AddSortingKernels(VectorKernel base, VectorFunction* func) {
}
}
+class TableSorter : public TypeVisitor {
+ private:
+  // One sort key resolved against the table: the key column's type, null
+  // count, chunks and requested order. Type and chunks are cached as raw,
+  // non-owning pointers, so the source ChunkedArray must outlive this object.
+  struct ResolvedSortKey {
+    ResolvedSortKey(const ChunkedArray& chunked_array, const SortOrder order)
+        : order(order),
+          type(chunked_array.type().get()),
+          null_count(chunked_array.null_count()),
+          num_chunks(chunked_array.num_chunks()) {
+      chunks.reserve(num_chunks);
+      for (const auto& chunk : chunked_array.chunks()) {
+        chunks.push_back(chunk.get());
+      }
+    }
+
+    // Maps the logical row `index` of the column to the chunk that holds it
+    // and writes the row's offset within that chunk to `chunk_index`.
+    // With a single chunk the index is passed through without a range check.
+    // Otherwise the chunks are scanned in order; nullptr is returned (and
+    // `chunk_index` left untouched) if `index` lies past the last chunk.
+    template <typename ArrayType>
+    ArrayType* ResolveChunk(int64_t index, int64_t& chunk_index) const {
+      if (num_chunks == 1) {
+        chunk_index = index;
+        return static_cast<ArrayType*>(chunks[0]);
+      }
+      int64_t remaining = index;
+      for (Array* chunk : chunks) {
+        const int64_t length = chunk->length();
+        if (remaining < length) {
+          chunk_index = remaining;
+          return static_cast<ArrayType*>(chunk);
+        }
+        remaining -= length;
+      }
+      return nullptr;
+    }
+
+    SortOrder order;
+    DataType* type;              // not owned
+    int64_t null_count;          // nulls across all chunks of the column
+    size_t num_chunks;
+    std::vector<Array*> chunks;  // not owned, in column order
+  };
+
+  // Row comparator over the resolved sort keys. Implemented as a TypeVisitor:
+  // Compare() dispatches on each key's DataType to the matching Visit()
+  // overload, which runs the typed CompareType<> instantiation.
+  class Comparer : public TypeVisitor {
+   public:
+    // Resolves every sort key name to a column of `table`. On the first name
+    // that does not exist, status() becomes Invalid and resolution stops.
+    Comparer(const Table& table, const std::vector<SortKey>& sort_keys)
+        : TypeVisitor(), status_(Status::OK()) {
+      for (const auto& sort_key : sort_keys) {
+        const auto& chunked_array = table.GetColumnByName(sort_key.name);
+        if (!chunked_array) {
+          status_ = Status::Invalid("Nonexistent sort key column: ", sort_key.name);
+          return;
+        }
+        sort_keys_.emplace_back(*chunked_array, sort_key.order);
+      }
+    }
+
+    // Non-OK if the constructor could not resolve a sort key column, or if
+    // Compare() hit a key whose Accept() failed (a type with no Visit overload
+    // below). Compare() never resets a failure back to OK.
+    Status status() { return status_; }
+
+    const std::vector<ResolvedSortKey>& sort_keys() { return sort_keys_; }
+
+    // Returns true if row `left` orders strictly before row `right`. Keys are
+    // compared starting at `start_sort_key_index`, moving to the next key only
+    // on a tie. If a key's Accept() fails, the failure is stored in status()
+    // and the comparison stops there.
+    bool Compare(uint64_t left, uint64_t right, size_t start_sort_key_index) {
+      current_left_ = left;
+      current_right_ = right;
+      current_compared_ = 0;
+      auto num_sort_keys = sort_keys_.size();
+      for (size_t i = start_sort_key_index; i < num_sort_keys; ++i) {
+        current_sort_key_index_ = i;
+        Status visit_status = sort_keys_[i].type->Accept(this);
+        if (!visit_status.ok()) {
+          // Record only failures: assigning every key's status would let a
+          // later, successfully compared key overwrite this error with OK.
+          status_ = visit_status;
+          break;
+        }
+        if (current_compared_ != 0) {
+          break;
+        }
+      }
+      return current_compared_ < 0;
+    }
+
+#define VISIT(TYPE)                                 \
+    Status Visit(const TYPE##Type&) override {       \
+      current_compared_ = CompareType<TYPE##Type>(); \
+      return Status::OK();                           \
+    }
+
+    VISIT(Int8)
+    VISIT(Int16)
+    VISIT(Int32)
+    VISIT(Int64)
+    VISIT(UInt8)
+    VISIT(UInt16)
+    VISIT(UInt32)
+    VISIT(UInt64)
+    VISIT(Float)
+    VISIT(Double)
+    VISIT(String)
+    VISIT(Binary)
+    VISIT(LargeString)
+    VISIT(LargeBinary)
+
+#undef VISIT
+
+   private:
+    // NaN detection for the floating-point key types. Any other value type
+    // handled above (integers, string/binary views) never holds NaN.
+    template <typename Value>
+    static bool IsNaN(const Value&) {
+      return false;
+    }
+    static bool IsNaN(float value) { return value != value; }
+    static bool IsNaN(double value) { return value != value; }
+
+    // Three-way comparison of the current row pair on the current sort key:
+    // negative if left orders first, positive if right does, 0 on a tie.
+    // Nulls order after every non-null value, and NaNs after every other
+    // non-null value, in both ASCENDING and DESCENDING order. Only ordinary
+    // values are reversed for DESCENDING. The NaN branch is needed because NaN
+    // is neither equal to, greater than, nor less than anything; without it a
+    // NaN would compare as "less" in both directions, which is not a strict
+    // weak ordering.
+    template <typename Type>
+    int32_t CompareType() {
+      using ArrayType = typename TypeTraits<Type>::ArrayType;
+      const auto& sort_key = sort_keys_[current_sort_key_index_];
+      auto order = sort_key.order;
+      int64_t index_left = 0;
+      auto array_left = sort_key.ResolveChunk<ArrayType>(current_left_, index_left);
+      int64_t index_right = 0;
+      auto array_right = sort_key.ResolveChunk<ArrayType>(current_right_, index_right);
+      if (sort_key.null_count > 0) {
+        auto is_null_left = array_left->IsNull(index_left);
+        auto is_null_right = array_right->IsNull(index_right);
+        if (is_null_left && is_null_right) {
+          return 0;
+        } else if (is_null_left) {
+          return 1;
+        } else if (is_null_right) {
+          return -1;
+        }
+      }
+      auto left = array_left->GetView(index_left);
+      auto right = array_right->GetView(index_right);
+      const bool is_nan_left = IsNaN(left);
+      const bool is_nan_right = IsNaN(right);
+      if (is_nan_left || is_nan_right) {
+        // Like nulls, NaNs are placed independently of the sort order.
+        if (is_nan_left && is_nan_right) {
+          return 0;
+        }
+        return is_nan_left ? 1 : -1;
+      }
+      int32_t compared;
+      if (left == right) {
+        compared = 0;
+      } else if (left > right) {
+        compared = 1;
+      } else {
+        compared = -1;
+      }
+      if (order == SortOrder::DESCENDING) {
+        compared = -compared;
+      }
+      return compared;
+    }
+
+    Status status_;
+    std::vector<ResolvedSortKey> sort_keys_;
+    int64_t current_left_;
+    int64_t current_right_;
+    size_t current_sort_key_index_;
+    int32_t current_compared_;
+  };
+
+ public:
+ TableSorter(uint64_t* indices_begin, uint64_t* indices_end, const Table&
table,
+ const SortOptions& options)
+ : indices_begin_(indices_begin),
+ indices_end_(indices_end),
+ comparer_(table, options.sort_keys) {}
+
+ Status Sort() {
+ ARROW_RETURN_NOT_OK(comparer_.status());
+ return comparer_.sort_keys()[0].type->Accept(this);
+ }
+
+  // One Visit overload per supported key type. Each forwards to the
+  // SortInternal<> instantiation for that Arrow type. Types not listed here
+  // fall back to the TypeVisitor base implementation.
+#define TABLE_SORTER_VISIT(ARROW_TYPE)             \
+  Status Visit(const ARROW_TYPE##Type&) override { \
+    return SortInternal<ARROW_TYPE##Type>();      \
+  }
+
+  TABLE_SORTER_VISIT(Int8)
+  TABLE_SORTER_VISIT(Int16)
+  TABLE_SORTER_VISIT(Int32)
+  TABLE_SORTER_VISIT(Int64)
+  TABLE_SORTER_VISIT(UInt8)
+  TABLE_SORTER_VISIT(UInt16)
+  TABLE_SORTER_VISIT(UInt32)
+  TABLE_SORTER_VISIT(UInt64)
+  TABLE_SORTER_VISIT(Float)
+  TABLE_SORTER_VISIT(Double)
+  TABLE_SORTER_VISIT(String)
+  TABLE_SORTER_VISIT(Binary)
+  TABLE_SORTER_VISIT(LargeString)
+  TABLE_SORTER_VISIT(LargeBinary)
+
+#undef TABLE_SORTER_VISIT
+
+ private:
+ template <typename Type>
+ Status SortInternal() {
Review comment:
Added.
----------------------------------------------------------------
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.
For queries about this service, please contact Infrastructure at:
users@infra.apache.org