hamilton-earthscope commented on code in PR #749: URL: https://github.com/apache/arrow-go/pull/749#discussion_r3048332046
########## arrow/compute/internal/kernels/vector_sort.go: ########## @@ -0,0 +1,476 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +//go:build go1.18 + +package kernels + +import ( + "fmt" + "slices" + + "github.com/apache/arrow-go/v18/arrow" + "github.com/apache/arrow-go/v18/arrow/array" + "github.com/apache/arrow-go/v18/arrow/compute/exec" +) + +// SortOrder specifies the sort order for sorting operations. +type SortOrder int8 + +const ( + Ascending SortOrder = iota + Descending +) + +// NullPlacement specifies where null values should be placed in the sort order. +type NullPlacement int8 + +const ( + NullsAtEnd NullPlacement = iota + NullsAtStart +) + +// SortOptions defines options for the sort_indices function. +type SortOptions struct { + Order SortOrder `compute:"order"` + NullPlacement NullPlacement `compute:"null_placement"` +} + +func (SortOptions) TypeName() string { return "SortOptions" } + +type SortState = SortOptions + +// SortKey defines a column to sort by with its ordering and null placement options. +type SortKey struct { + ColumnIndex int + Order SortOrder + NullPlacement NullPlacement +} + +// Chunk-aware sort_indices: logical row IDs 0..n-1, no chunk concatenation. Structure follows +// Apache Arrow C++ vector_sort.cc / vector_sort_internal.h (see vector_sort_internal.go). +// +// Per-column data uses a dense logicalRowMap for O(1) chunk resolution under random compares. +// Each sortable physical type has a dedicated comparator struct in vector_sort_physical.go (C++ +// ConcreteColumnComparator<T> shape): full monomorphization for the hot compare path; no +// value-level compare func pointer. +// compareRowsForKey implements the same ordering as C++ (null placement, NaN, Order). +// +// Single key: ChunkedArraySorter — arraySortOneColumnRange per chunk (PartitionNullsOnly / +// PartitionNullLikes + stable_sort finites), then pairwise merge (ChunkedMergeImpl-style merge +// using full row order; C++ splits null / non-null merge when the type has null-likes). +// +// Multi-key, aligned chunks: TableSorter — per-chunk RadixRecordBatchSorter or +// MultipleKeyRecordBatchSorter, then merge. +// +// Multi-key, single segment: RadixRecordBatchSorter (<= maxRadixSortKeys) or +// MultipleKeyRecordBatchSorter (> maxRadixSortKeys). + +// maxRadixSortKeys matches Arrow C++ kMaxRadixSortKeys (vector_sort.cc): above this, one global +// multi-key stable sort is used instead of MSD radix. +const maxRadixSortKeys = 8 + +// columnComparator is the Go analogue of compute::internal::ColumnComparator (vector_sort_internal.h): +// per-column row compare + null / null-like metadata for partitioning. +type columnComparator interface { + // compareRowsForKey returns -1 if i before j, +1 if i after j, 0 if tied on this column + // (both null, or both non-null and equal), so the caller may advance to the next sort key. + compareRowsForKey(i, j uint64, key SortKey) int + // isNullAt returns true if the global row index is null. + isNullAt(global uint64) bool + // hasNullLikeValues returns true if the column has null-like values. + hasNullLikeValues() bool + // isNullLikeAt returns true if the global row index is a null-like value. + isNullLikeAt(global uint64) bool + // columnHasValidityNulls mirrors Array::null_count() != 0; when false, C++ skips PartitionNullsOnly. + columnHasValidityNulls() bool +} + +// multiColumnComparator compares two logical rows (global uint64 indices) lexicographically +// across every sort key. That matches C++ MultipleKeyComparator::CompareInternal(left, right, 0) +// (vector_sort_internal.h), but it is not a port of the whole MultipleKeyComparator type: C++ keeps +// ResolvedSortKey per column, uses Location (int64 batch row vs ChunkLocation on tables), builds +// virtual ColumnComparator instances, and passes start_sort_key_index for radix tails and other +// partial key ranges — in Go those suffix compares are makeTailComparator(comparators, keys, from). +type multiColumnComparator struct { + columns []columnComparator + keys []SortKey +} + +// compare is a three-way ordering for stable sort / merge: negative if idxA before idxB, etc. +func (m *multiColumnComparator) compare(idxA, idxB uint64) int { + for i, key := range m.keys { + if cmpVal := m.columns[i].compareRowsForKey(idxA, idxB, key); cmpVal != 0 { + return cmpVal + } + } + return 0 +} + +func extensionStorageFixedSizeBinaryChunks(chunks []arrow.Array) ([]arrow.Array, error) { + out := make([]arrow.Array, len(chunks)) + for i, ch := range chunks { + ext, ok := ch.(array.ExtensionArray) + if !ok { + return nil, fmt.Errorf("%w: extension column must implement array.ExtensionArray", arrow.ErrInvalid) + } + st := ext.Storage() + if st.DataType().ID() != arrow.FIXED_SIZE_BINARY { + return nil, fmt.Errorf("%w: sorting extension columns is only supported when storage is fixed_size_binary (got %s)", + arrow.ErrNotImplemented, st.DataType()) + } + out[i] = st + } + return out, nil +} + +func newFixedSizeBinaryComparator(chunks []arrow.Array, numRows int, vn bool) (columnComparator, error) { + f0, ok := chunks[0].(*array.FixedSizeBinary) + if !ok { + return nil, fmt.Errorf("%w: expected *array.FixedSizeBinary chunk", arrow.ErrInvalid) + } + w := f0.DataType().(*arrow.FixedSizeBinaryType).ByteWidth + for i := 1; i < len(chunks); i++ { + fi, ok := chunks[i].(*array.FixedSizeBinary) + if !ok { + return nil, fmt.Errorf("%w: expected *array.FixedSizeBinary chunk", arrow.ErrInvalid) + } + wi := fi.DataType().(*arrow.FixedSizeBinaryType).ByteWidth + if wi != w { + return nil, fmt.Errorf("%w: fixed_size_binary chunks must have the same byte width (%d vs %d)", + arrow.ErrInvalid, w, wi) + } + } + return newPhysicalSortFixedSizeBinaryColumn(chunks, numRows, vn), nil +} + +// createChunkedComparator builds a column comparator for these chunks (one Arrow type for all chunks). +func createChunkedComparator(chunks []arrow.Array, numRows int) (columnComparator, error) { + if len(chunks) == 0 { + return nil, fmt.Errorf("%w: cannot create comparator for empty chunk list", arrow.ErrInvalid) + } + if totalChunkRows(chunks) != numRows { + return nil, fmt.Errorf("%w: chunk row count does not match column length", arrow.ErrInvalid) + } + + validityNulls := chunksHaveNulls(chunks) + typeID := chunks[0].DataType().ID() + switch typeID { + case arrow.INT8: + return newPhysicalSortInt8Column(chunks, numRows, validityNulls), nil + case arrow.INT16: + return newPhysicalSortInt16Column(chunks, numRows, validityNulls), nil + case arrow.INT32: + return newPhysicalSortInt32Column(chunks, numRows, validityNulls), nil + case arrow.DATE32: + return newPhysicalSortDate32Column(chunks, numRows, validityNulls), nil + case arrow.TIME32: + return newPhysicalSortTime32Column(chunks, numRows, validityNulls), nil + case arrow.INT64: + return newPhysicalSortInt64Column(chunks, numRows, validityNulls), nil + case arrow.DATE64: + return newPhysicalSortDate64Column(chunks, numRows, validityNulls), nil + case arrow.TIME64: + return newPhysicalSortTime64Column(chunks, numRows, validityNulls), nil + case arrow.TIMESTAMP: + return newPhysicalSortTimestampColumn(chunks, numRows, validityNulls), nil + case arrow.DURATION: + return newPhysicalSortDurationColumn(chunks, numRows, validityNulls), nil + case arrow.UINT8: + return newPhysicalSortUint8Column(chunks, numRows, validityNulls), nil + case arrow.UINT16: + return newPhysicalSortUint16Column(chunks, numRows, validityNulls), nil + case arrow.UINT32: + return newPhysicalSortUint32Column(chunks, numRows, validityNulls), nil + case arrow.UINT64: + return newPhysicalSortUint64Column(chunks, numRows, validityNulls), nil + case arrow.FLOAT16: + return newPhysicalSortFloat16Column(chunks, numRows, validityNulls), nil + case arrow.FLOAT32: + return newPhysicalSortFloat32Column(chunks, numRows, validityNulls), nil + case arrow.FLOAT64: + return newPhysicalSortFloat64Column(chunks, numRows, validityNulls), nil + case arrow.DECIMAL32: + return newPhysicalSortDecimal32Column(chunks, numRows, validityNulls), nil + case arrow.DECIMAL64: + return newPhysicalSortDecimal64Column(chunks, numRows, validityNulls), nil + case arrow.DECIMAL128: + return newPhysicalSortDecimal128Column(chunks, numRows, validityNulls), nil + case arrow.DECIMAL256: + return newPhysicalSortDecimal256Column(chunks, numRows, validityNulls), nil + case arrow.INTERVAL_MONTHS: + return newPhysicalSortMonthIntervalColumn(chunks, numRows, validityNulls), nil + case arrow.INTERVAL_DAY_TIME: + return newPhysicalSortDayTimeColumn(chunks, numRows, validityNulls), nil + case arrow.INTERVAL_MONTH_DAY_NANO: + return newPhysicalSortMonthDayNanoColumn(chunks, numRows, validityNulls), nil + case arrow.BOOL: + return newPhysicalSortBoolColumn(chunks, numRows, validityNulls), nil + case arrow.STRING: + return newPhysicalSortStringColumn(chunks, numRows, validityNulls), nil + case arrow.LARGE_STRING: + return newPhysicalSortLargeStringColumn(chunks, numRows, validityNulls), nil + case arrow.BINARY: + return newPhysicalSortBinaryColumn(chunks, numRows, validityNulls), nil + case arrow.LARGE_BINARY: + return newPhysicalSortLargeBinaryColumn(chunks, numRows, validityNulls), nil + case arrow.FIXED_SIZE_BINARY: + return newFixedSizeBinaryComparator(chunks, numRows, validityNulls) + case arrow.EXTENSION: + storageChunks, err := extensionStorageFixedSizeBinaryChunks(chunks) + if err != nil { + return nil, err + } + return newFixedSizeBinaryComparator(storageChunks, numRows, validityNulls) + default: + return nil, fmt.Errorf("%w: sorting not supported for type %s", arrow.ErrNotImplemented, typeID) + } +} + +// chunkIndexSpan represents a contiguous range of indices in the global order. +type chunkIndexSpan struct { + lo, hi int +} + +// mergeAdjacentStable merges sorted adjacent ranges [a0,a1) and [b0,b1) (a1 == b0) into indices[lo:hi] +// using a strict weak order: i is ordered before j iff less(i,j). Tie-breaking prefers the left range +// (stable merge, same as C++ std::merge with !comp(right,left)). +func mergeAdjacentStable(indices, tmp []uint64, a0, a1, b0, b1 int, less func(a, b uint64) bool) { + i, j, k := a0, b0, a0 + for i < a1 && j < b1 { + if !less(indices[j], indices[i]) { + tmp[k] = indices[i] + i++ + } else { + tmp[k] = indices[j] + j++ + } + k++ + } + for i < a1 { + tmp[k] = indices[i] + k++ + i++ + } + for j < b1 { + tmp[k] = indices[j] + k++ + j++ + } + copy(indices[a0:b1], tmp[a0:b1]) +} + +// pairwiseMergeSortedSpans merges already-sorted adjacent index spans (chunk batch rows in global +// order), matching Arrow C++ ChunkedMergeImpl / TableSorter batch merge (vector_sort.cc). +// spanScratch must have capacity >= len(spans); it ping-pongs with spans' backing during merging. +func pairwiseMergeSortedSpans(indices, tmp []uint64, spans []chunkIndexSpan, less func(a, b uint64) bool, spanScratch []chunkIndexSpan) { + if len(spans) <= 1 { + return + } + if cap(spanScratch) < len(spans) { + panic("kernels: spanScratch cap < len(spans)") + } + cur := spans + other := spanScratch[:0] + for len(cur) > 1 { + other = other[:0] + for i := 0; i < len(cur); i += 2 { + if i+1 < len(cur) { + s0, s1 := cur[i], cur[i+1] + mergeAdjacentStable(indices, tmp, s0.lo, s0.hi, s1.lo, s1.hi, less) + other = append(other, chunkIndexSpan{s0.lo, s1.hi}) + } else { + other = append(other, cur[i]) + } + } + cur, other = other, cur + } +} + +// alignedChunkBoundaries reports cumulative row offsets for chunk boundaries when every sort column +// has the same chunk count and matching chunk lengths (typical for Arrow tables). +func alignedChunkBoundaries(columns []*arrow.Chunked) ([]int, bool) { + if len(columns) == 0 { + return nil, false + } + ch0 := columns[0].Chunks() + n := len(ch0) + if n == 0 { + return nil, false + } + offs := make([]int, n+1) + for i := 0; i < n; i++ { Review Comment: Does that mean `arrow/compute/registry.go` as well? -- This is an automated message from the Apache Git Service. To respond to the message, please log on to GitHub and use the URL above to go to the specific comment. To unsubscribe, e-mail: [email protected] For queries about this service, please contact Infrastructure at: [email protected]
