Copilot commented on code in PR #49679: URL: https://github.com/apache/arrow/pull/49679#discussion_r3506469853
########## cpp/src/arrow/compute/kernels/vector_search_sorted.cc: ########## @@ -0,0 +1,1215 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +#include "arrow/compute/api_vector.h" + +#include <algorithm> +#include <memory> +#include <numeric> +#include <optional> +#include <ranges> +#include <type_traits> +#include <utility> + +#include "arrow/array/array_primitive.h" +#include "arrow/array/array_run_end.h" +#include "arrow/array/concatenate.h" +#include "arrow/array/util.h" +#include "arrow/buffer_builder.h" +#include "arrow/chunk_resolver.h" +#include "arrow/compute/function.h" +#include "arrow/compute/kernels/codegen_internal.h" +#include "arrow/compute/kernels/vector_sort_internal.h" +#include "arrow/compute/registry.h" +#include "arrow/compute/registry_internal.h" +#include "arrow/type_traits.h" +#include "arrow/util/checked_cast.h" +#include "arrow/util/logging_internal.h" +#include "arrow/util/ree_util.h" +#include "arrow/util/unreachable.h" + +namespace arrow { + +using internal::checked_cast; + +namespace compute::internal { +namespace { + +/// Return the static default options instance used by the meta-function. 
+const SearchSortedOptions* GetDefaultSearchSortedOptions() { + static const auto kDefaultSearchSortedOptions = SearchSortedOptions::Defaults(); + return &kDefaultSearchSortedOptions; +} + +const FunctionDoc search_sorted_doc( + "Find insertion indices for sorted input", + ("Return the index where each needle should be inserted in a sorted input array\n" + "to maintain ascending order.\n" + "\n" + "With side='left', returns the first suitable index (lower bound).\n" + "With side='right', returns the last suitable index (upper bound).\n" + "\n" + "The searched values may be provided as an array or chunked array and must\n" + "already be sorted in ascending order. Null values in the searched array are\n" + "supported when clustered entirely at the start or\n" + "entirely at the end. Non-null needles are matched only against the non-null\n" + "portion of the searched array. Needles may be a scalar, array, or chunked\n" + "array. Null needles emit nulls in the output."), + {"values", "needles"}, "SearchSortedOptions"); + +// This file implements search_sorted as a normalization pipeline around one +// typed binary-search core. +// +// The searched values are first validated, unwrapped to their logical type, +// and adapted to a uniform accessor interface. Plain arrays and chunked arrays +// expose logical element access directly. Run-end encoded (REE) arrays expose a +// search domain over physical runs while still translating insertion positions +// back to logical indices. +// +// Values null handling is normalized before any search happens. Nulls are only +// accepted when clustered entirely at the start or entirely at the end of the +// sorted values. The implementation computes the contiguous non-null logical +// window once and then searches only within that window. For REE values this +// requires logical null counting, because nullness lives in the values child +// rather than in a top-level validity bitmap. +// +// Needles follow two execution paths. 
Scalars, plain arrays, and chunked arrays +// are visited element by element through one callback interface, producing one +// insertion index per logical needle and propagating null needles as null +// outputs. REE needles are handled separately: the kernel searches each +// physical REE value once, rebuilds a temporary REE UInt64 result with the same +// logical run ends, and then run-end decodes it back to the dense public +// output shape. +// +// The actual comparison/search step is shared across all normalized inputs. +// After dispatching to the logical/physical Arrow representation, the kernel +// runs a lower-bound or upper-bound binary search depending on +// `SearchSortedOptions::side`, then maps the found position back to the caller- +// visible logical insertion index. +// +// Output materialization is centralized in a UInt64 builder with an optional +// validity bitmap. Non-null-only needles only build the values buffer, while +// nullable needles also emit the null bitmap. 
+// +// High-level flow: +// +// values datum +// | +// +--> ValidateSortedValuesInput +// | +// +--> LogicalType / FindNonNullValuesRange +// | +// +--> VisitValuesAccessor +// | +// +--> PlainArrayAccessor +// | +// +--> RunEndEncodedValuesAccessor +// | +// +--> ChunkedArrayAccessor +// | +// `--> ChunkedRunEndEncodedValuesAccessor +// +// needles datum +// | +// +--> ValidateNeedleInput +// | +// +--> DatumHasNulls +// | +// +--> REE needles +// | +--> search physical runs once +// | +--> rebuild temporary REE uint64 result +// | `--> RunEndDecode back to dense output +// | +// `--> VisitNeedleRuns +// | +// +--> scalar needle -> one logical element +// | +// +--> plain array -> one logical element per slot +// | +// `--> chunked input -> recurse chunk by chunk +// +// normalized values accessor + normalized needle runs +// | +// `--> FindInsertionPoint<T> +// | +// +--> side = left -> lower_bound semantics +// | +// `--> side = right -> upper_bound semantics +// +// result materialization +// | +// +--> no needle nulls +// | `--> InsertionIndexBuilder<false> +// | `--> fill uint64 buffer directly +// | +// `--> nullable needles +// `--> InsertionIndexBuilder<true> +// +--> AppendNulls for null runs +// `--> bulk fill repeated indices and validity bits +// +// A rough map of the file: +// +// [validation + type helpers] +// | +// [value accessors] +// | +// [needle visitors] +// | +// [typed search + output helpers] +// | +// [meta-function dispatch] +// + +#define VISIT_SEARCH_SORTED_PHYSICAL_TYPES(VISIT) \ + VISIT(BooleanType) \ + VISIT(Int8Type) \ + VISIT(Int16Type) \ + VISIT(Int32Type) \ + VISIT(Int64Type) \ + VISIT(UInt8Type) \ + VISIT(UInt16Type) \ + VISIT(UInt32Type) \ + VISIT(UInt64Type) \ + VISIT(FloatType) \ + VISIT(DoubleType) \ + VISIT(BinaryType) \ + VISIT(LargeBinaryType) \ + VISIT(BinaryViewType) + +template <typename ArrowType> +using SearchValue = typename GetViewType<ArrowType>::T; + +struct NonNullValuesRange { + int64_t offset = 0; + int64_t 
length = 0; + + /// Return whether the range spans the full searched values input. + bool is_identity(int64_t full_length) const { + return (offset == 0) && (length == full_length); + } +}; + +// Convert ArrayData to its physical representation so that typed accessors +// can be constructed with a physical ArrowType (e.g. Date32 → Int32). +// For REE arrays, only the values child type is converted; the REE wrapper +// type stays unchanged. +inline std::shared_ptr<ArrayData> ToPhysicalData( + const std::shared_ptr<ArrayData>& data, + const std::shared_ptr<DataType>& physical_type) { + if (data->type->id() == Type::RUN_END_ENCODED) { + auto result = data->Copy(); + auto values_copy = result->child_data[1]->Copy(); + values_copy->type = physical_type; + result->child_data[1] = std::move(values_copy); + return result; + } + auto result = data->Copy(); + result->type = physical_type; + return result; +} + +/// Read a run-end value from any supported run-end integer representation. +inline int64_t GetRunEndValue(const ArraySpan& run_ends, int64_t physical_index) { + switch (run_ends.type->id()) { + case Type::INT16: + return run_ends.GetValues<int16_t>(1)[physical_index]; + case Type::INT32: + return run_ends.GetValues<int32_t>(1)[physical_index]; + case Type::INT64: + return run_ends.GetValues<int64_t>(1)[physical_index]; + default: + DCHECK(false) << "Unexpected run-end type for search_sorted values: " + << run_ends.type->ToString(); + return 0; + } +} + +/// Comparator implementing Arrow's ascending-order semantics for supported types. +template <typename ArrowType> +struct SearchSortedCompare { + using ValueType = SearchValue<ArrowType>; + + int operator()(const ValueType& left, const ValueType& right) const { + return CompareTypeValues<ArrowType>(left, right, SortOrder::Ascending, + NullPlacement::AtEnd); + } +}; + +/// Access logical values from a plain Arrow array. 
+template <typename ArrowType> +class PlainArrayAccessor { + public: + using ArrayType = typename TypeTraits<ArrowType>::ArrayType; + using ValueType = SearchValue<ArrowType>; + + /// Build a typed accessor over a plain array payload. + explicit PlainArrayAccessor(const std::shared_ptr<ArrayData>& array_data) + : array_(array_data) {} + + /// Return the logical length of the searched values. + int64_t length() const { return array_.length(); } + + /// Return the logical value at the given logical position. + ValueType Value(int64_t index) const { + return GetViewType<ArrowType>::LogicalValue(array_.GetView(index)); + } + + /// Convert a binary-search position in the plain array directly back to the + /// logical insertion index returned to callers. + uint64_t LogicalInsertionIndex(int64_t index) const { + return static_cast<uint64_t>(index); + } + + private: + ArrayType array_; +}; + +/// Access logical values from a run-end encoded Arrow array. +template <typename ArrowType> +class RunEndEncodedValuesAccessor { + public: + using ArrayType = typename TypeTraits<ArrowType>::ArrayType; + using ValueType = SearchValue<ArrowType>; + + /// Build a typed accessor over a run-end encoded payload. + explicit RunEndEncodedValuesAccessor(const RunEndEncodedArray& array) + : array_(array), + values_(array.values()->data()), + array_span_(*array.data()), + physical_range_(::arrow::ree_util::FindPhysicalRange(array_span_, array.offset(), + array.length())) {} + + /// Return the number of physical runs used as the search domain. + int64_t length() const { return physical_range_.second; } + + /// Return the logical value at the given physical run position. + ValueType Value(int64_t index) const { + const auto physical_index = physical_range_.first + index; + return GetViewType<ArrowType>::LogicalValue(values_.GetView(physical_index)); + } + + /// Return the number of null physical runs in the selected physical range. 
+ int64_t NullCount() const { + return values_.Slice(physical_range_.first, physical_range_.second)->null_count(); + } + + /// Translate a binary-search position over physical runs into a logical array + /// insertion index. + uint64_t LogicalInsertionIndex(int64_t index) const { + DCHECK_GE(index, 0); + DCHECK_LE(index, physical_range_.second); + + if (index == 0) { + return 0; + } + if (index == physical_range_.second) { + return static_cast<uint64_t>(array_.length()); + } + return static_cast<uint64_t>(LogicalRunEnd(physical_range_.first + index - 1)); + } + + /// Return the logical length of the sliced REE values view. + int64_t logical_length() const { return array_.length(); } + + private: + /// Return the logical run end corresponding to a physical run index. + int64_t LogicalRunEnd(int64_t physical_index) const { + // The run-end value is an absolute (cumulative) logical position in the + // full array. Subtract array_.offset() to get a position relative to the + // current slice. Clamp to 0: when the slice offset falls in the middle of + // a physical run, the first run end after the slice start is always positive, + // but defensive clamping guards against edge cases where a run-end lands + // exactly at (or before) the slice offset. + const int64_t logical_run_end = std::max<int64_t>( + GetRunEndValue(::arrow::ree_util::RunEndsArray(array_span_), physical_index) - + array_.offset(), + 0); + // The physical range returned by FindPhysicalRange may include a trailing + // run that extends beyond the logical slice. Clamp to array_.length() so + // the result stays within the slice boundary. + return std::min(logical_run_end, array_.length()); + } + + const RunEndEncodedArray& array_; + ArrayType values_; + ArraySpan array_span_; + std::pair<int64_t, int64_t> physical_range_; +}; + +/// Access logical values from a chunked Arrow array without combining chunks. 
+template <typename ArrowType> +class ChunkedArrayAccessor { + public: + using ArrayType = typename TypeTraits<ArrowType>::ArrayType; + using ValueType = SearchValue<ArrowType>; + + /// Build an accessor that resolves logical indices across chunk boundaries + /// without concatenating the input. + explicit ChunkedArrayAccessor(const ChunkedArray& chunked_array) + : chunked_array_(chunked_array), resolver_(chunked_array.chunks()) { + chunks_.reserve(static_cast<size_t>(chunked_array_.num_chunks())); + for (const auto& chunk : chunked_array_.chunks()) { + DCHECK_NE(chunk->type_id(), Type::RUN_END_ENCODED); + chunks_.emplace_back(chunk->data()); + } + } + + /// Return the total logical length across all chunks. + int64_t length() const { return chunked_array_.length(); } + + /// Resolve a logical index to its chunk-local storage and return that value. + ValueType Value(int64_t index) const { + const auto location = resolver_.Resolve(index); + DCHECK_LT(location.chunk_index, chunked_array_.num_chunks()); + return GetViewType<ArrowType>::LogicalValue( + chunks_[location.chunk_index].GetView(location.index_in_chunk)); + } + + /// Chunked plain arrays already operate on logical indices directly. + uint64_t LogicalInsertionIndex(int64_t index) const { + return static_cast<uint64_t>(index); + } + + private: + const ChunkedArray& chunked_array_; + ChunkResolver resolver_; + std::vector<ArrayType> chunks_; +}; + +template <typename ArrowType> +class ChunkedRunEndEncodedValuesAccessor { + public: + using ValueType = SearchValue<ArrowType>; + + /// Flatten a chunked REE input into a logical sequence of physical runs while + /// preserving enough offset information to map search results back to logical + /// array positions. 
+ explicit ChunkedRunEndEncodedValuesAccessor(const ChunkedArray& chunked_array) + : chunked_array_(chunked_array), logical_length_(chunked_array.length()) { + const auto chunk_count = chunked_array_.num_chunks(); + logical_offsets_.reserve(static_cast<size_t>(chunk_count)); + accessors_.reserve(static_cast<size_t>(chunk_count)); + std::vector<int64_t> run_offsets; + run_offsets.reserve(static_cast<size_t>(chunk_count) + 1); + run_offsets.push_back(0); + + int64_t selected_run_start = 0; + int64_t selected_logical_start = 0; + + for (const auto& chunk : chunked_array_.chunks()) { + if (chunk->length() != 0) { + DCHECK_EQ(chunk->type_id(), Type::RUN_END_ENCODED); + + const auto& ree_chunk = checked_cast<const RunEndEncodedArray&>(*chunk); + logical_offsets_.push_back(selected_logical_start); + accessors_.emplace_back(ree_chunk); + + selected_run_start += accessors_.back().length(); + selected_logical_start += chunk->length(); + run_offsets.push_back(selected_run_start); + } + } + + DCHECK_EQ(selected_logical_start, logical_length_); + total_run_count_ = selected_run_start; + run_resolver_.emplace(std::move(run_offsets)); + } + + /// Return the total number of searchable physical runs across all chunks. + int64_t length() const { return total_run_count_; } + + /// Resolve a global physical-run index to the owning chunk accessor. + ValueType Value(int64_t index) const { + const auto [chunk_index, local_index] = ResolveRun(index); + return accessors_[chunk_index].Value(local_index); + } + + /// Count leading null physical runs across chunks. Validation guarantees that + /// any null runs are clustered entirely at one end of the logical values. 
+ int64_t NullCount() const { + int64_t null_run_count = 0; + for (const auto& accessor : accessors_) { + const auto local_null_run_count = accessor.NullCount(); + null_run_count += local_null_run_count; + if (local_null_run_count != accessor.length()) { + break; + } + } + return null_run_count; + } Review Comment: `ChunkedRunEndEncodedValuesAccessor::NullCount()` only counts leading null runs (it breaks after the first non-all-null chunk). For chunked REE values with *trailing* null runs, `MakePhysicalNonNullValuesRange()` will not exclude the null-run suffix, and the binary search can probe `Value(it)` on a null run (reading an undefined physical value). Since validation guarantees null runs are clustered at one end, this should return the *total* null-run count across all chunks (or otherwise account for trailing nulls). ########## cpp/src/arrow/compute/kernels/vector_search_sorted_test.cc: ########## @@ -0,0 +1,624 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. 
+ +#include <memory> +#include <string> +#include <vector> + +#include <gtest/gtest.h> + +#include "arrow/array/concatenate.h" +#include "arrow/compute/api.h" +#include "arrow/compute/kernels/test_util_internal.h" +#include "arrow/testing/gtest_util.h" + +namespace arrow { + +using internal::checked_cast; + +namespace compute { +namespace { + +Result<std::shared_ptr<Array>> REEFromJSON(const std::shared_ptr<DataType>& ree_type, + const std::string& json) { + auto ree_type_ptr = checked_cast<const RunEndEncodedType*>(ree_type.get()); + auto array = ArrayFromJSON(ree_type_ptr->value_type(), json); + ARROW_ASSIGN_OR_RAISE( + auto datum, RunEndEncode(array, RunEndEncodeOptions{ree_type_ptr->run_end_type()})); + return datum.make_array(); +} + +void CheckSearchSorted(const Datum& values, const Datum& needles, + SearchSortedOptions::Side side, const std::string& expected_json) { + ASSERT_OK_AND_ASSIGN(auto result, + SearchSorted(values, needles, SearchSortedOptions(side))); + ASSERT_TRUE(result.is_array()); + ASSERT_OK(result.make_array()->ValidateFull()); + + AssertArraysEqual(*ArrayFromJSON(uint64(), expected_json), *result.make_array()); +} + +void CheckSearchSorted(const Datum& values, const Datum& needles, + const std::string& expected_left_json, + const std::string& expected_right_json) { + CheckSearchSorted(values, needles, SearchSortedOptions::Left, expected_left_json); + CheckSearchSorted(values, needles, SearchSortedOptions::Right, expected_right_json); +} + +void CheckSimpleSearchSorted(const std::shared_ptr<DataType>& type, + const std::string& values_json, + const std::string& needles_json, + const std::string& expected_left_json, + const std::string& expected_right_json) { + auto values = ArrayFromJSON(type, values_json); + auto needles = ArrayFromJSON(type, needles_json); + + CheckSearchSorted(Datum(values), Datum(needles), expected_left_json, + expected_right_json); +} + +void CheckScalarSearchSorted(const Datum& values, const std::shared_ptr<Array>& 
needles, + const std::string& expected_left_json, + const std::string& expected_right_json) { + auto expected_left = ArrayFromJSON(uint64(), expected_left_json); + auto expected_right = ArrayFromJSON(uint64(), expected_right_json); + + ASSERT_EQ(needles->length(), expected_left->length()); + ASSERT_EQ(needles->length(), expected_right->length()); + + for (int64_t index = 0; index < needles->length(); ++index) { + ASSERT_OK_AND_ASSIGN(auto needle, needles->GetScalar(index)); + ASSERT_OK_AND_ASSIGN(auto left, + SearchSorted(values, Datum(needle), + SearchSortedOptions(SearchSortedOptions::Left))); + ASSERT_OK_AND_ASSIGN(auto right, + SearchSorted(values, Datum(needle), + SearchSortedOptions(SearchSortedOptions::Right))); + + ASSERT_TRUE(left.is_scalar()); + ASSERT_TRUE(right.is_scalar()); + + ASSERT_OK_AND_ASSIGN(auto expected_left_scalar, expected_left->GetScalar(index)); + ASSERT_OK_AND_ASSIGN(auto expected_right_scalar, expected_right->GetScalar(index)); + AssertScalarsEqual(*expected_left_scalar, *left.scalar()); + AssertScalarsEqual(*expected_right_scalar, *right.scalar()); + } +} + +void CheckSimpleScalarSearchSorted(const std::shared_ptr<DataType>& type, + const std::string& values_json, + const std::string& needles_json, + const std::string& expected_left_json, + const std::string& expected_right_json) { + auto values = ArrayFromJSON(type, values_json); + auto needles = ArrayFromJSON(type, needles_json); + CheckScalarSearchSorted(Datum(values), needles, expected_left_json, + expected_right_json); +} + +void CheckSimpleSearchSortedAndScalar(const std::shared_ptr<DataType>& type, + const std::string& values_json, + const std::string& needles_json, + const std::string& expected_left_json, + const std::string& expected_right_json) { + auto values = ArrayFromJSON(type, values_json); + auto needles = ArrayFromJSON(type, needles_json); + + CheckSearchSorted(Datum(values), Datum(needles), expected_left_json, + expected_right_json); + 
CheckScalarSearchSorted(Datum(values), needles, expected_left_json, + expected_right_json); +} + +void CheckChunkedSearchSortedAndConcatenated(const std::shared_ptr<ChunkedArray>& values, + const std::shared_ptr<ChunkedArray>& needles, + const std::string& expected_left_json, + const std::string& expected_right_json) { + CheckSearchSorted(Datum(values), Datum(needles), expected_left_json, + expected_right_json); + + ASSERT_OK_AND_ASSIGN(auto concatenated_values, Concatenate(values->chunks())); + ASSERT_OK_AND_ASSIGN(auto concatenated_needles, Concatenate(needles->chunks())); + + CheckSearchSorted(Datum(concatenated_values), Datum(concatenated_needles), + expected_left_json, expected_right_json); +} + +struct SearchSortedSmokeCase { + std::string name; + std::shared_ptr<DataType> type; + std::string values_json; + std::string needles_json; + std::string expected_left_json; + std::string expected_right_json; +}; + +std::vector<SearchSortedSmokeCase> SupportedTypeSmokeCases() { + return { + {"Boolean", boolean(), "[false, false, false, true, true]", "[false, true]", + "[0, 3]", "[3, 5]"}, + { + "Int8", + int8(), + "[1, 3, 3, 5, 8]", + "[0, 3, 9]", + "[0, 1, 5]", + "[0, 3, 5]", + }, + { + "Int16", + int16(), + "[1, 3, 3, 5, 8]", + "[0, 3, 9]", + "[0, 1, 5]", + "[0, 3, 5]", + }, + { + "Int32", + int32(), + "[1, 3, 3, 5, 8]", + "[0, 3, 9]", + "[0, 1, 5]", + "[0, 3, 5]", + }, + { + "Int64", + int64(), + "[1, 3, 3, 5, 8]", + "[0, 3, 9]", + "[0, 1, 5]", + "[0, 3, 5]", + }, + { + "UInt8", + uint8(), + "[1, 3, 3, 5, 8]", + "[0, 3, 9]", + "[0, 1, 5]", + "[0, 3, 5]", + }, + { + "UInt16", + uint16(), + "[1, 3, 3, 5, 8]", + "[0, 3, 9]", + "[0, 1, 5]", + "[0, 3, 5]", + }, + { + "UInt32", + uint32(), + "[1, 3, 3, 5, 8]", + "[0, 3, 9]", + "[0, 1, 5]", + "[0, 3, 5]", + }, + { + "UInt64", + uint64(), + "[1, 3, 3, 5, 8]", + "[0, 3, 9]", + "[0, 1, 5]", + "[0, 3, 5]", + }, + {"Float32", float32(), "[1.0, 3.0, 3.0, 5.0, 8.0]", "[0.0, 3.0, 9.0]", "[0, 1, 5]", + "[0, 3, 5]"}, + {"Float64", 
float64(), "[1.0, 3.0, 3.0, 5.0, 8.0]", "[0.0, 3.0, 9.0]", "[0, 1, 5]", + "[0, 3, 5]"}, + { + "Date32", + date32(), + "[1, 3, 3, 5, 8]", + "[0, 3, 9]", + "[0, 1, 5]", + "[0, 3, 5]", + }, + { + "Date64", + date64(), + "[86400000, 259200000, 259200000, 432000000, 691200000]", + "[0, 259200000, 777600000]", + "[0, 1, 5]", + "[0, 3, 5]", + }, + {"Time32", time32(TimeUnit::SECOND), "[1, 3, 3, 5, 8]", "[0, 3, 9]", "[0, 1, 5]", + "[0, 3, 5]"}, + {"Time64", time64(TimeUnit::NANO), "[1, 3, 3, 5, 8]", "[0, 3, 9]", "[0, 1, 5]", + "[0, 3, 5]"}, + {"Timestamp", timestamp(TimeUnit::SECOND), + R"(["1970-01-02", "1970-01-04", "1970-01-04", "1970-01-06", "1970-01-09"])", + R"(["1970-01-01", "1970-01-04", "1970-01-10"])", "[0, 1, 5]", "[0, 3, 5]"}, + {"Duration", duration(TimeUnit::NANO), "[1, 3, 3, 5, 8]", "[0, 3, 9]", "[0, 1, 5]", + "[0, 3, 5]"}, + {"Binary", binary(), R"(["aa", "bb", "bb", "dd", "ff"])", R"(["a", "bb", "z"])", + "[0, 1, 5]", "[0, 3, 5]"}, + {"String", utf8(), R"(["aa", "bb", "bb", "dd", "ff"])", R"(["a", "bb", "z"])", + "[0, 1, 5]", "[0, 3, 5]"}, + {"LargeBinary", large_binary(), R"(["aa", "bb", "bb", "dd", "ff"])", + R"(["a", "bb", "z"])", "[0, 1, 5]", "[0, 3, 5]"}, + {"LargeString", large_utf8(), R"(["aa", "bb", "bb", "dd", "ff"])", + R"(["a", "bb", "z"])", "[0, 1, 5]", "[0, 3, 5]"}, + {"BinaryView", binary_view(), R"(["aa", "bb", "bb", "dd", "ff"])", + R"(["a", "bb", "z"])", "[0, 1, 5]", "[0, 3, 5]"}, + {"StringView", utf8_view(), R"(["aa", "bb", "bb", "dd", "ff"])", + R"(["a", "bb", "z"])", "[0, 1, 5]", "[0, 3, 5]"}, + }; +} + +class SearchSortedSupportedTypesTest + : public ::testing::TestWithParam<SearchSortedSmokeCase> {}; + +TEST(SearchSorted, BasicLeftRight) { + CheckSimpleSearchSorted(int64(), "[100, 200, 200, 300, 300]", "[50, 200, 250, 400]", + "[0, 1, 3, 5]", "[0, 3, 3, 5]"); +} + +TEST(SearchSorted, ScalarNeedle) { + auto values = ArrayFromJSON(int32(), "[1, 3, 5, 7]"); + + ASSERT_OK_AND_ASSIGN( + auto result, SearchSorted(Datum(values), 
Datum(std::make_shared<Int32Scalar>(5)), + SearchSortedOptions(SearchSortedOptions::Right))); + + ASSERT_TRUE(result.is_scalar()); + ASSERT_EQ(checked_cast<const UInt64Scalar&>(*result.scalar()).value, 3); +} + +TEST(SearchSorted, ScalarStringNeedle) { + auto values = ArrayFromJSON(utf8(), R"(["aa", "bb", "bb", "cc"])"); + + ASSERT_OK_AND_ASSIGN( + auto result, + SearchSorted(Datum(values), Datum(std::make_shared<StringScalar>("bb")), + SearchSortedOptions(SearchSortedOptions::Right))); + + ASSERT_TRUE(result.is_scalar()); + ASSERT_EQ(checked_cast<const UInt64Scalar&>(*result.scalar()).value, 3); +} + +TEST(SearchSorted, EmptyHaystack) { + CheckSimpleSearchSorted(int16(), "[]", "[1, 2, 3]", "[0, 0, 0]", "[0, 0, 0]"); +} + +TEST(SearchSorted, ValuesWithLeadingNulls) { + CheckSimpleSearchSorted(int32(), "[null, 200, 300, 300]", "[50, 200, 250, 400]", + "[1, 1, 2, 4]", "[1, 2, 2, 4]"); +} + +TEST(SearchSorted, ValuesAllNull) { + CheckSimpleSearchSorted(int32(), "[null, null, null]", "[50, 200, null]", + "[0, 0, null]", "[0, 0, null]"); +} + +TEST(SearchSorted, ValuesWithTrailingNulls) { + CheckSimpleSearchSorted(int32(), "[200, 300, 300, null, null]", "[50, 200, 250, 400]", + "[0, 0, 1, 3]", "[0, 1, 1, 3]"); +} + +TEST(SearchSorted, ValuesWithInterspersedNullsAreRejected) { + auto values = ArrayFromJSON(int32(), "[null, 200, null]"); + auto needles = ArrayFromJSON(int32(), "[200]"); + + EXPECT_RAISES_WITH_MESSAGE_THAT( + Invalid, + ::testing::HasSubstr("search_sorted values with nulls must be clustered at the " + "start or end."), + SearchSorted(Datum(values), Datum(needles))); +} + +TEST(SearchSorted, FloatValuesWithTrailingNaNsAndNulls) { + CheckSimpleSearchSortedAndScalar(float64(), "[1.0, 3.0, 3.0, 5.0, NaN, NaN, null]", + "[0.0, 3.0, 4.0, NaN]", "[0, 1, 3, 4]", + "[0, 3, 3, 6]"); +} + +TEST(SearchSorted, FloatValuesWithTrailingNaNsAndNullsAndNullNeedles) { + CheckSimpleSearchSortedAndScalar(float64(), "[1.0, 3.0, 3.0, 5.0, NaN, NaN, null]", + "[0.0, 3.0, 4.0, 
NaN, null]", "[0, 1, 3, 4, null]", + "[0, 3, 3, 6, null]"); +} + +TEST(SearchSorted, FloatValuesWithLeadingNullsAndTrailingNaNsAndNullNeedles) { + CheckSimpleSearchSortedAndScalar(float64(), "[null, 1.0, 3.0, 3.0, 5.0, NaN, NaN]", + "[0.0, 3.0, 4.0, NaN, null]", "[1, 2, 4, 5, null]", + "[1, 4, 4, 7, null]"); +} + +TEST(SearchSorted, NullNeedlesEmitNull) { + CheckSimpleSearchSorted(int32(), "[null, 200, 300, 300]", "[null, 50, 200, null, 400]", + "[null, 1, 1, null, 4]", "[null, 1, 2, null, 4]"); + + auto values = ArrayFromJSON(int32(), "[null, 200, 300, 300]"); + + ASSERT_OK_AND_ASSIGN(auto scalar_result, + SearchSorted(Datum(values), Datum(std::make_shared<Int32Scalar>()), + SearchSortedOptions(SearchSortedOptions::Left))); + ASSERT_TRUE(scalar_result.is_scalar()); + ASSERT_FALSE(scalar_result.scalar()->is_valid); + ASSERT_TRUE(scalar_result.scalar()->type->Equals(uint64())); +} + +TEST(SearchSorted, ChunkedValues) { + auto values = std::make_shared<ChunkedArray>(ArrayVector{ + ArrayFromJSON(int32(), "[10, 10]"), + ArrayFromJSON(int32(), "[10, 30, 50]"), + }); + auto needles = ArrayFromJSON(int32(), "[10, 20, 60]"); + + CheckSearchSorted(Datum(values), Datum(needles), "[0, 3, 5]", "[3, 3, 5]"); +} + +TEST(SearchSorted, ChunkedNeedles) { + auto values = ArrayFromJSON(int32(), "[1, 1, 3, 5, 8]"); + auto needles = std::make_shared<ChunkedArray>(ArrayVector{ + ArrayFromJSON(int32(), "[null, 0, 1]"), + ArrayFromJSON(int32(), "[4, null, 9]"), + }); + + CheckSearchSorted(Datum(values), Datum(needles), "[null, 0, 0, 3, null, 5]", + "[null, 0, 2, 3, null, 5]"); +} + +TEST(SearchSorted, ChunkedValuesChunkedNeedles) { + auto values = std::make_shared<ChunkedArray>(ArrayVector{ + ArrayFromJSON(int32(), "[1, 1]"), + ArrayFromJSON(int32(), "[3]"), + ArrayFromJSON(int32(), "[5, 8]"), + }); + auto needles = std::make_shared<ChunkedArray>(ArrayVector{ + ArrayFromJSON(int32(), "[null, 0, 1]"), + ArrayFromJSON(int32(), "[4]"), + ArrayFromJSON(int32(), "[null, 9]"), + }); + + 
CheckChunkedSearchSortedAndConcatenated(values, needles, "[null, 0, 0, 3, null, 5]", + "[null, 0, 2, 3, null, 5]"); +} + +TEST(SearchSorted, ChunkedRunEndEncodedValues) { + auto values_type = run_end_encoded(int16(), int32()); + ASSERT_OK_AND_ASSIGN(auto left_chunk, REEFromJSON(values_type, "[10, 10, 10]")); + ASSERT_OK_AND_ASSIGN(auto right_chunk, REEFromJSON(values_type, "[30, 30, 50]")); + auto values = std::make_shared<ChunkedArray>(ArrayVector{left_chunk, right_chunk}); + auto needles = ArrayFromJSON(int32(), "[5, 10, 20, 30, 40, 50, 60]"); + + CheckSearchSorted(Datum(values), Datum(needles), "[0, 0, 3, 3, 5, 5, 6]", + "[0, 3, 3, 5, 5, 6, 6]"); +} + +TEST(SearchSorted, SlicedChunkedRunEndEncodedValues) { + auto values_type = run_end_encoded(int16(), int32()); + ASSERT_OK_AND_ASSIGN(auto left_chunk, REEFromJSON(values_type, "[10, 10, 10]")); + ASSERT_OK_AND_ASSIGN(auto right_chunk, REEFromJSON(values_type, "[30, 30, 50]")); + auto values = std::make_shared<ChunkedArray>( + ArrayVector{left_chunk->Slice(1, 2), right_chunk->Slice(0, 2)}); + auto needles = ArrayFromJSON(int32(), "[5, 10, 20, 30, 40, 50]"); + + CheckSearchSorted(Datum(values), Datum(needles), "[0, 0, 2, 2, 4, 4]", + "[0, 2, 2, 4, 4, 4]"); +} + +TEST(SearchSorted, ChunkedRunEndEncodedNeedles) { + auto values = ArrayFromJSON(int32(), "[1, 1, 3, 5, 8]"); + auto needles_type = run_end_encoded(int32(), int32()); + ASSERT_OK_AND_ASSIGN(auto left_chunk, REEFromJSON(needles_type, "[0, 0, 1, 1]")); + ASSERT_OK_AND_ASSIGN(auto right_chunk, REEFromJSON(needles_type, "[4, 4, 9]")); + auto needles = std::make_shared<ChunkedArray>(ArrayVector{left_chunk, right_chunk}); + + CheckSearchSorted(Datum(values), Datum(needles), SearchSortedOptions::Right, + "[0, 0, 2, 2, 3, 3, 5]"); +} + +TEST(SearchSorted, ChunkedRunEndEncodedValuesLeadingNullsAcrossEmptyChunks) { + auto values_type = run_end_encoded(int16(), int32()); + ASSERT_OK_AND_ASSIGN(auto empty_chunk, REEFromJSON(values_type, "[]")); + 
ASSERT_OK_AND_ASSIGN(auto null_chunk, REEFromJSON(values_type, "[null, null]")); + ASSERT_OK_AND_ASSIGN(auto data_chunk, REEFromJSON(values_type, "[2, 4, 4]")); + auto values = std::make_shared<ChunkedArray>( + ArrayVector{empty_chunk, null_chunk, empty_chunk, data_chunk}); + auto needles = ArrayFromJSON(int32(), "[1, 4, 8]"); + + CheckSearchSorted(Datum(values), Datum(needles), "[2, 3, 5]", "[2, 5, 5]"); +} + +TEST(SearchSorted, ChunkedRunEndEncodedAllNullValuesAcrossEmptyChunks) { Review Comment: There is test coverage for chunked REE values with leading nulls across empty chunks, but not for the symmetric case where null runs are clustered at the end across chunk boundaries/empty chunks. Adding that case would catch issues in the chunked-REE null-run trimming logic. ########## docs/source/cpp/compute.rst: ########## @@ -1901,6 +1903,13 @@ in the respective option classes. * \(7) The output is an array of indices into the input, that define a non-stable sort of the input. +* \(8) The first argument must be sorted in ascending order. If it contains + nulls, they must be clustered entirely at the start or the end, and non-null + needles are matched only against the non-null portion. The second argument + may be a scalar, array, or run-end encoded array. Null needles yield null + outputs. Both arguments must have the same logical type. A scalar needle + yields a UInt64 scalar; otherwise the result is a UInt64 array. Review Comment: The C++ compute docs entry for (8) doesn’t mention that `search_sorted` supports chunked arrays (and chunked run-end encoded arrays) for both inputs, even though the kernel and C++ API do. This can mislead users into thinking they need to concatenate first. -- This is an automated message from the Apache Git Service. To respond to the message, please log on to GitHub and use the URL above to go to the specific comment. 
To unsubscribe, e-mail: [email protected] For queries about this service, please contact Infrastructure at: [email protected]
