This is an automated email from the ASF dual-hosted git repository. wesm pushed a commit to branch master in repository https://gitbox.apache.org/repos/asf/arrow.git
The following commit(s) were added to refs/heads/master by this push: new da752fd ARROW-2104: [C++] take kernel functions for nested types da752fd is described below commit da752fddab34d71e5c5f648b2cb20740c16ce11e Author: Benjamin Kietzman <bengil...@gmail.com> AuthorDate: Thu Jun 27 15:45:14 2019 -0500 ARROW-2104: [C++] take kernel functions for nested types Take now supports gathering from List, FixedSizeList, Map, and Struct arrays. Union is not yet supported Author: Benjamin Kietzman <bengil...@gmail.com> Closes #4531 from bkietz/2104-Implement-take-kernel-functions-nested-a and squashes the following commits: 73262bd44 <Benjamin Kietzman> clang-format eaf8302ea <Benjamin Kietzman> add benchmarks for Take() 5981ee8d4 <Benjamin Kietzman> rewrite Filter(string array) benchmark to respect memory budget d60ff7c0d <Benjamin Kietzman> cast size_t -> int16_t, update fixed_size_binary(0) test 30d587252 <Benjamin Kietzman> add LiteralType constructor for gcc 4.8 e73c1ec23 <Benjamin Kietzman> validate arrays in pyarrow's Take() test ac0e391aa <Benjamin Kietzman> add benchmark for filtering a StringArray 0fe81648e <Benjamin Kietzman> added requested tests and ValidateArray calls 55854836d <Benjamin Kietzman> add doccomments e6081b027 <Benjamin Kietzman> remove redundant bounds checking in Struct case d9c4a1a64 <Benjamin Kietzman> add Take() permutation inversion test abc1733bd <Benjamin Kietzman> simplify looping through IndexSequences c3e812982 <Benjamin Kietzman> rewrite python Take() test c7f2e4021 <Benjamin Kietzman> repair bounds checking 6a14c93e8 <Benjamin Kietzman> clang-format, explicit cast 6c453f334 <Benjamin Kietzman> lint fixes 227ea5516 <Benjamin Kietzman> add tests for Take(nested types) 65dcd9075 <Benjamin Kietzman> refactor Take and Filter to share code through Taker<> --- cpp/src/arrow/array-test.cc | 2 +- cpp/src/arrow/array.cc | 13 +- cpp/src/arrow/array.h | 8 +- cpp/src/arrow/array/builder_primitive.cc | 4 +- cpp/src/arrow/buffer-builder.h | 3 + cpp/src/arrow/compute/kernels/CMakeLists.txt | 1 + cpp/src/arrow/compute/kernels/filter-benchmark.cc | 31 ++ cpp/src/arrow/compute/kernels/filter-test.cc | 57 ++- cpp/src/arrow/compute/kernels/filter.cc | 426 ++--------------- cpp/src/arrow/compute/kernels/filter.h | 9 +- cpp/src/arrow/compute/kernels/take-benchmark.cc | 147 ++++++ cpp/src/arrow/compute/kernels/take-internal.h | 553 ++++++++++++++++++++++ cpp/src/arrow/compute/kernels/take-test.cc | 386 +++++++++++++-- cpp/src/arrow/compute/kernels/take.cc | 226 +++------ cpp/src/arrow/compute/kernels/take.h | 32 +- cpp/src/arrow/compute/kernels/util-internal.h | 2 +- python/pyarrow/tests/test_compute.py | 20 +- 17 files changed, 1303 insertions(+), 617 deletions(-) diff --git a/cpp/src/arrow/array-test.cc b/cpp/src/arrow/array-test.cc index 606ca71..2005a0d 100644 --- a/cpp/src/arrow/array-test.cc +++ b/cpp/src/arrow/array-test.cc @@ -1311,7 +1311,7 @@ TEST_F(TestFWBinaryArray, ZeroSize) { const auto& fw_array = checked_cast<const FixedSizeBinaryArray&>(*array); // data is never allocated - ASSERT_TRUE(fw_array.values() == nullptr); + ASSERT_EQ(fw_array.values()->size(), 0); ASSERT_EQ(0, fw_array.byte_width()); ASSERT_EQ(6, array->length()); diff --git a/cpp/src/arrow/array.cc b/cpp/src/arrow/array.cc index 95acc6b..9b66af2 100644 --- a/cpp/src/arrow/array.cc +++ b/cpp/src/arrow/array.cc @@ -301,12 +301,21 @@ MapArray::MapArray(const std::shared_ptr<ArrayData>& data) { SetData(data); } MapArray::MapArray(const std::shared_ptr<DataType>& type, int64_t length, const std::shared_ptr<Buffer>& offsets, - const std::shared_ptr<Array>& keys, const std::shared_ptr<Array>& values, const std::shared_ptr<Buffer>& null_bitmap, int64_t null_count, int64_t offset) { + SetData(ArrayData::Make(type, length, {null_bitmap, offsets}, {values->data()}, + null_count, offset)); +} + +MapArray::MapArray(const std::shared_ptr<DataType>& type, int64_t length, + const std::shared_ptr<Buffer>& offsets, + const std::shared_ptr<Array>& keys, + const std::shared_ptr<Array>& items, + const std::shared_ptr<Buffer>& null_bitmap, int64_t null_count, + int64_t offset) { auto pair_data = ArrayData::Make(type->children()[0]->type(), keys->data()->length, - {nullptr}, {keys->data(), values->data()}, 0, offset); + {nullptr}, {keys->data(), items->data()}, 0, offset); auto map_data = ArrayData::Make(type, length, {null_bitmap, offsets}, {pair_data}, null_count, offset); SetData(map_data); diff --git a/cpp/src/arrow/array.h b/cpp/src/arrow/array.h index 5cca9db..1e163b7 100644 --- a/cpp/src/arrow/array.h +++ b/cpp/src/arrow/array.h @@ -565,7 +565,13 @@ class ARROW_EXPORT MapArray : public ListArray { MapArray(const std::shared_ptr<DataType>& type, int64_t length, const std::shared_ptr<Buffer>& value_offsets, - const std::shared_ptr<Array>& keys, const std::shared_ptr<Array>& values, + const std::shared_ptr<Array>& keys, const std::shared_ptr<Array>& items, + const std::shared_ptr<Buffer>& null_bitmap = NULLPTR, + int64_t null_count = kUnknownNullCount, int64_t offset = 0); + + MapArray(const std::shared_ptr<DataType>& type, int64_t length, + const std::shared_ptr<Buffer>& value_offsets, + const std::shared_ptr<Array>& values, const std::shared_ptr<Buffer>& null_bitmap = NULLPTR, int64_t null_count = kUnknownNullCount, int64_t offset = 0); diff --git a/cpp/src/arrow/array/builder_primitive.cc b/cpp/src/arrow/array/builder_primitive.cc index 34d198e..c7d934f 100644 --- a/cpp/src/arrow/array/builder_primitive.cc +++ b/cpp/src/arrow/array/builder_primitive.cc @@ -65,9 +65,9 @@ Status BooleanBuilder::Resize(int64_t capacity) { } Status BooleanBuilder::FinishInternal(std::shared_ptr<ArrayData>* out) { - std::shared_ptr<Buffer> data, null_bitmap; - RETURN_NOT_OK(data_builder_.Finish(&data)); + std::shared_ptr<Buffer> null_bitmap, data; RETURN_NOT_OK(null_bitmap_builder_.Finish(&null_bitmap)); + RETURN_NOT_OK(data_builder_.Finish(&data)); *out = ArrayData::Make(boolean(), length_, {null_bitmap, data}, null_count_); diff --git a/cpp/src/arrow/buffer-builder.h b/cpp/src/arrow/buffer-builder.h index f069ea4..85f36ee 100644 --- a/cpp/src/arrow/buffer-builder.h +++ b/cpp/src/arrow/buffer-builder.h @@ -145,6 +145,9 @@ class ARROW_EXPORT BufferBuilder { ARROW_RETURN_NOT_OK(Resize(size_, shrink_to_fit)); if (size_ != 0) buffer_->ZeroPadding(); *out = buffer_; + if (*out == NULLPTR) { + ARROW_RETURN_NOT_OK(AllocateBuffer(pool_, 0, out)); + } Reset(); return Status::OK(); } diff --git a/cpp/src/arrow/compute/kernels/CMakeLists.txt b/cpp/src/arrow/compute/kernels/CMakeLists.txt index 1bbb5bc..3d9da8b 100644 --- a/cpp/src/arrow/compute/kernels/CMakeLists.txt +++ b/cpp/src/arrow/compute/kernels/CMakeLists.txt @@ -34,3 +34,4 @@ add_arrow_benchmark(compare-benchmark PREFIX "arrow-compute") add_arrow_test(take-test PREFIX "arrow-compute") add_arrow_test(filter-test PREFIX "arrow-compute") add_arrow_benchmark(filter-benchmark PREFIX "arrow-compute") +add_arrow_benchmark(take-benchmark PREFIX "arrow-compute") diff --git a/cpp/src/arrow/compute/kernels/filter-benchmark.cc b/cpp/src/arrow/compute/kernels/filter-benchmark.cc index 3eb460a..0ae528b 100644 --- a/cpp/src/arrow/compute/kernels/filter-benchmark.cc +++ b/cpp/src/arrow/compute/kernels/filter-benchmark.cc @@ -68,6 +68,30 @@ static void FilterFixedSizeList1Int64(benchmark::State& state) { } } +static void FilterString(benchmark::State& state) { + RegressionArgs args(state); + + int32_t string_min_length = 0, string_max_length = 128; + int32_t string_mean_length = (string_max_length + string_min_length) / 2; + // for an array of 50% null strings, we need to generate twice as many strings + // to ensure that they have an average of args.size total characters + auto array_size = + static_cast<int64_t>(args.size / string_mean_length / (1 - args.null_proportion)); + + auto rand = random::RandomArrayGenerator(kSeed); + auto array = std::static_pointer_cast<StringArray>(rand.String( + array_size, string_min_length, string_max_length, args.null_proportion)); + auto filter = std::static_pointer_cast<BooleanArray>( + rand.Boolean(array_size, 0.75, args.null_proportion)); + + FunctionContext ctx; + for (auto _ : state) { + Datum out; + ABORT_NOT_OK(Filter(&ctx, Datum(array), Datum(filter), &out)); + benchmark::DoNotOptimize(out); + } +} + BENCHMARK(FilterInt64) ->Apply(RegressionSetArgs) ->Args({1 << 20, 1}) @@ -82,5 +106,12 @@ BENCHMARK(FilterFixedSizeList1Int64) ->MinTime(1.0) ->Unit(benchmark::TimeUnit::kNanosecond); +BENCHMARK(FilterString) + ->Apply(RegressionSetArgs) + ->Args({1 << 20, 1}) + ->Args({1 << 23, 1}) + ->MinTime(1.0) + ->Unit(benchmark::TimeUnit::kNanosecond); + } // namespace compute } // namespace arrow diff --git a/cpp/src/arrow/compute/kernels/filter-test.cc b/cpp/src/arrow/compute/kernels/filter-test.cc index 7b34949..033efee 100644 --- a/cpp/src/arrow/compute/kernels/filter-test.cc +++ b/cpp/src/arrow/compute/kernels/filter-test.cc @@ -34,6 +34,8 @@ namespace compute { using internal::checked_pointer_cast; using util::string_view; +constexpr auto kSeed = 0x0ff1ce; + template <typename ArrowType> class TestFilterKernel : public ComputeFixture, public TestBase { protected: @@ -42,23 +44,29 @@ class TestFilterKernel : public ComputeFixture, public TestBase { const std::shared_ptr<Array>& expected) { std::shared_ptr<Array> actual; ASSERT_OK(arrow::compute::Filter(&this->ctx_, *values, *filter, &actual)); + ASSERT_OK(ValidateArray(*actual)); AssertArraysEqual(*expected, *actual); } + void AssertFilter(const std::shared_ptr<DataType>& type, const std::string& values, const std::string& filter, const std::string& expected) { std::shared_ptr<Array> actual; ASSERT_OK(this->Filter(type, values, filter, &actual)); + ASSERT_OK(ValidateArray(*actual)); AssertArraysEqual(*ArrayFromJSON(type, expected), *actual); } + Status Filter(const std::shared_ptr<DataType>& type, const std::string& values, const std::string& filter, std::shared_ptr<Array>* out) { return arrow::compute::Filter(&this->ctx_, *ArrayFromJSON(type, values), *ArrayFromJSON(boolean(), filter), out); } + void ValidateFilter(const std::shared_ptr<Array>& values, const std::shared_ptr<Array>& filter_boxed) { std::shared_ptr<Array> filtered; ASSERT_OK(arrow::compute::Filter(&this->ctx_, *values, *filter_boxed, &filtered)); + ASSERT_OK(ValidateArray(*filtered)); auto filter = checked_pointer_cast<BooleanArray>(filter_boxed); int64_t values_i = 0, filtered_i = 0; @@ -84,11 +92,13 @@ class TestFilterKernelWithNull : public TestFilterKernel<NullType> { protected: void AssertFilter(const std::string& values, const std::string& filter, const std::string& expected) { - TestFilterKernel<NullType>::AssertFilter(utf8(), values, filter, expected); + TestFilterKernel<NullType>::AssertFilter(null(), values, filter, expected); } }; TEST_F(TestFilterKernelWithNull, FilterNull) { + this->AssertFilter("[]", "[]", "[]"); + this->AssertFilter("[null, null, null]", "[0, 1, 0]", "[null]"); this->AssertFilter("[null, null, null]", "[1, 1, 0]", "[null, null]"); } @@ -102,6 +112,8 @@ class TestFilterKernelWithBoolean : public TestFilterKernel<BooleanType> { }; TEST_F(TestFilterKernelWithBoolean, FilterBoolean) { + this->AssertFilter("[]", "[]", "[]"); + this->AssertFilter("[true, false, true]", "[0, 1, 0]", "[false]"); this->AssertFilter("[null, false, true]", "[0, 1, 0]", "[false]"); this->AssertFilter("[true, false, true]", "[null, 1, 0]", "[null, false]"); @@ -114,6 +126,7 @@ class TestFilterKernelWithNumeric : public TestFilterKernel<ArrowType> { const std::string& expected) { TestFilterKernel<ArrowType>::AssertFilter(type_singleton(), values, filter, expected); } + std::shared_ptr<DataType> type_singleton() { return TypeTraits<ArrowType>::type_singleton(); } @@ -135,13 +148,16 @@ TYPED_TEST(TestFilterKernelWithNumeric, FilterNumeric) { this->AssertFilter("[null, 8, 9]", "[0, 1, 0]", "[8]"); this->AssertFilter("[7, 8, 9]", "[null, 1, 0]", "[null, 8]"); this->AssertFilter("[7, 8, 9]", "[1, null, 1]", "[7, null, 9]"); + + std::shared_ptr<Array> arr; + ASSERT_RAISES(Invalid, this->Filter(this->type_singleton(), "[7, 8, 9]", "[]", &arr)); } TYPED_TEST(TestFilterKernelWithNumeric, FilterRandomNumeric) { - auto rand = random::RandomArrayGenerator(0x5416447); + auto rand = random::RandomArrayGenerator(kSeed); for (size_t i = 3; i < 13; i++) { const int64_t length = static_cast<int64_t>(1ULL << i); - for (auto null_probability : {0.0, 0.01, 0.1, 0.25, 0.5, 1.0}) { + for (auto null_probability : {0.0, 0.01, 0.25, 1.0}) { for (auto filter_probability : {0.0, 0.01, 0.1, 0.25, 0.5, 1.0}) { auto values = rand.Numeric<TypeParam>(length, 0, 127, null_probability); auto filter = rand.Boolean(length, filter_probability, null_probability); @@ -191,7 +207,7 @@ TYPED_TEST(TestFilterKernelWithNumeric, CompareScalarAndFilterRandomNumeric) { using ArrayType = typename TypeTraits<TypeParam>::ArrayType; using CType = typename TypeTraits<TypeParam>::CType; - auto rand = random::RandomArrayGenerator(0x5416447); + auto rand = random::RandomArrayGenerator(kSeed); for (size_t i = 3; i < 13; i++) { const int64_t length = static_cast<int64_t>(1ULL << i); // TODO(bkietz) rewrite with some nulls @@ -206,6 +222,7 @@ TYPED_TEST(TestFilterKernelWithNumeric, CompareScalarAndFilterRandomNumeric) { &selection)); ASSERT_OK(arrow::compute::Filter(&this->ctx_, Datum(array), selection, &filtered)); auto filtered_array = filtered.make_array(); + ASSERT_OK(ValidateArray(*filtered_array)); auto expected = CompareAndFilter<TypeParam>(array->raw_values(), array->length(), c_fifty, op); ASSERT_ARRAYS_EQUAL(*filtered_array, *expected); @@ -216,7 +233,7 @@ TYPED_TEST(TestFilterKernelWithNumeric, CompareScalarAndFilterRandomNumeric) { TYPED_TEST(TestFilterKernelWithNumeric, CompareArrayAndFilterRandomNumeric) { using ArrayType = typename TypeTraits<TypeParam>::ArrayType; - auto rand = random::RandomArrayGenerator(0x5416447); + auto rand = random::RandomArrayGenerator(kSeed); for (size_t i = 3; i < 13; i++) { const int64_t length = static_cast<int64_t>(1ULL << i); auto lhs = @@ -230,6 +247,7 @@ TYPED_TEST(TestFilterKernelWithNumeric, CompareArrayAndFilterRandomNumeric) { &selection)); ASSERT_OK(arrow::compute::Filter(&this->ctx_, Datum(lhs), selection, &filtered)); auto filtered_array = filtered.make_array(); + ASSERT_OK(ValidateArray(*filtered_array)); auto expected = CompareAndFilter<TypeParam>(lhs->raw_values(), lhs->length(), rhs->raw_values(), op); ASSERT_ARRAYS_EQUAL(*filtered_array, *expected); @@ -242,7 +260,7 @@ TYPED_TEST(TestFilterKernelWithNumeric, ScalarInRangeAndFilterRandomNumeric) { using ArrayType = typename TypeTraits<TypeParam>::ArrayType; using CType = typename TypeTraits<TypeParam>::CType; - auto rand = random::RandomArrayGenerator(0x5416447); + auto rand = random::RandomArrayGenerator(kSeed); for (size_t i = 3; i < 13; i++) { const int64_t length = static_cast<int64_t>(1ULL << i); auto array = @@ -259,6 +277,7 @@ TYPED_TEST(TestFilterKernelWithNumeric, ScalarInRangeAndFilterRandomNumeric) { &selection)); ASSERT_OK(arrow::compute::Filter(&this->ctx_, Datum(array), selection, &filtered)); auto filtered_array = filtered.make_array(); + ASSERT_OK(ValidateArray(*filtered_array)); auto expected = CompareAndFilter<TypeParam>( array->raw_values(), array->length(), [&](CType e) { return (e > c_fifty) && (e < c_hundred); }); @@ -313,6 +332,32 @@ TEST_F(TestFilterKernelWithList, FilterListInt32) { this->AssertFilter(list(int32()), list_json, "[0, 1, 0, 1]", "[[1,2], [3]]"); } +TEST_F(TestFilterKernelWithList, FilterListListInt32) { + std::string list_json = R"([ + [], + [[1], [2, null, 2], []], + null, + [[3, null], null] + ])"; + auto type = list(list(int32())); + this->AssertFilter(type, list_json, "[0, 0, 0, 0]", "[]"); + this->AssertFilter(type, list_json, "[0, 1, 1, null]", R"([ + [[1], [2, null, 2], []], + null, + null + ])"); + this->AssertFilter(type, list_json, "[0, 0, 1, null]", "[null, null]"); + this->AssertFilter(type, list_json, "[1, 0, 0, 1]", R"([ + [], + [[3, null], null] + ])"); + this->AssertFilter(type, list_json, "[1, 1, 1, 1]", list_json); + this->AssertFilter(type, list_json, "[0, 1, 0, 1]", R"([ + [[1], [2, null, 2], []], + [[3, null], null] + ])"); +} + class TestFilterKernelWithFixedSizeList : public TestFilterKernel<FixedSizeListType> {}; TEST_F(TestFilterKernelWithFixedSizeList, FilterFixedSizeListInt32) { diff --git a/cpp/src/arrow/compute/kernels/filter.cc b/cpp/src/arrow/compute/kernels/filter.cc index 654ec61..8a07663 100644 --- a/cpp/src/arrow/compute/kernels/filter.cc +++ b/cpp/src/arrow/compute/kernels/filter.cc @@ -15,19 +15,17 @@ // specific language governing permissions and limitations // under the License. -#include <algorithm> +#include "arrow/compute/kernels/filter.h" + +#include <limits> #include <memory> #include <utility> -#include <vector> #include "arrow/builder.h" #include "arrow/compute/context.h" -#include "arrow/compute/kernels/filter.h" -#include "arrow/util/bit-util.h" +#include "arrow/compute/kernels/take-internal.h" #include "arrow/util/checked_cast.h" #include "arrow/util/logging.h" -#include "arrow/util/stl.h" -#include "arrow/visitor_inline.h" namespace arrow { namespace compute { @@ -35,32 +33,36 @@ namespace compute { using internal::checked_cast; using internal::checked_pointer_cast; -template <typename Builder> -Status MakeBuilder(MemoryPool* pool, const std::shared_ptr<DataType>& type, - std::unique_ptr<Builder>* out) { - std::unique_ptr<ArrayBuilder> builder; - RETURN_NOT_OK(MakeBuilder(pool, type, &builder)); - out->reset(checked_cast<Builder*>(builder.release())); - return Status::OK(); -} +// IndexSequence which yields the indices of positions in a BooleanArray +// which are either null or true +class FilterIndexSequence { + public: + // constexpr so we'll never instantiate bounds checking + constexpr bool never_out_of_bounds() const { return true; } + void set_never_out_of_bounds() {} -template <typename Builder, typename Scalar> -static Status UnsafeAppend(Builder* builder, Scalar&& value) { - builder->UnsafeAppend(std::forward<Scalar>(value)); - return Status::OK(); -} + constexpr FilterIndexSequence() = default; -static Status UnsafeAppend(BinaryBuilder* builder, util::string_view value) { - RETURN_NOT_OK(builder->ReserveData(static_cast<int64_t>(value.size()))); - builder->UnsafeAppend(value); - return Status::OK(); -} + FilterIndexSequence(const BooleanArray& filter, int64_t out_length) + : filter_(&filter), out_length_(out_length) {} -static Status UnsafeAppend(StringBuilder* builder, util::string_view value) { - RETURN_NOT_OK(builder->ReserveData(static_cast<int64_t>(value.size()))); - builder->UnsafeAppend(value); - return Status::OK(); -} + std::pair<int64_t, bool> Next() { + // skip until an index is found at which the filter is either null or true + while (filter_->IsValid(index_) && !filter_->Value(index_)) { + ++index_; + } + bool is_valid = filter_->IsValid(index_); + return std::make_pair(index_++, is_valid); + } + + int64_t length() const { return out_length_; } + + int64_t null_count() const { return filter_->null_count(); } + + private: + const BooleanArray* filter_ = nullptr; + int64_t index_ = 0, out_length_ = -1; +}; // TODO(bkietz) this can be optimized static int64_t OutputSize(const BooleanArray& filter) { @@ -75,358 +77,32 @@ static int64_t OutputSize(const BooleanArray& filter) { return size; } -template <typename ValueType> -class FilterImpl; - -template <> -class FilterImpl<NullType> : public FilterKernel { - public: - using FilterKernel::FilterKernel; - - Status Filter(FunctionContext* ctx, const Array& values, const BooleanArray& filter, - int64_t length, std::shared_ptr<Array>* out) override { - out->reset(new NullArray(length)); - return Status::OK(); - } -}; - -template <typename ValueType> -class FilterImpl : public FilterKernel { - public: - using ValueArray = typename TypeTraits<ValueType>::ArrayType; - using OutBuilder = typename TypeTraits<ValueType>::BuilderType; - - using FilterKernel::FilterKernel; - - Status Filter(FunctionContext* ctx, const Array& values, const BooleanArray& filter, - int64_t length, std::shared_ptr<Array>* out) override { - std::unique_ptr<OutBuilder> builder; - RETURN_NOT_OK(MakeBuilder(ctx->memory_pool(), type_, &builder)); - RETURN_NOT_OK(builder->Resize(OutputSize(filter))); - RETURN_NOT_OK(UnpackValuesNullCount(checked_cast<const ValueArray&>(values), filter, - builder.get())); - return builder->Finish(out); - } - - private: - Status UnpackValuesNullCount(const ValueArray& values, const BooleanArray& filter, - OutBuilder* builder) { - if (values.null_count() == 0) { - return UnpackIndicesNullCount<true>(values, filter, builder); - } - return UnpackIndicesNullCount<false>(values, filter, builder); - } - - template <bool AllValuesValid> - Status UnpackIndicesNullCount(const ValueArray& values, const BooleanArray& filter, - OutBuilder* builder) { - if (filter.null_count() == 0) { - return Filter<AllValuesValid, true>(values, filter, builder); - } - return Filter<AllValuesValid, false>(values, filter, builder); - } - - template <bool AllValuesValid, bool AllIndicesValid> - Status Filter(const ValueArray& values, const BooleanArray& filter, - OutBuilder* builder) { - for (int64_t i = 0; i < filter.length(); ++i) { - if (!AllIndicesValid && filter.IsNull(i)) { - builder->UnsafeAppendNull(); - continue; - } - if (!filter.Value(i)) { - continue; - } - if (!AllValuesValid && values.IsNull(i)) { - builder->UnsafeAppendNull(); - continue; - } - RETURN_NOT_OK(UnsafeAppend(builder, values.GetView(i))); - } - return Status::OK(); - } -}; - -template <> -class FilterImpl<StructType> : public FilterKernel { +class FilterKernelImpl : public FilterKernel { public: - FilterImpl(const std::shared_ptr<DataType>& type, - std::vector<std::unique_ptr<FilterKernel>> child_kernels) - : FilterKernel(type), child_kernels_(std::move(child_kernels)) {} + FilterKernelImpl(const std::shared_ptr<DataType>& type, + std::unique_ptr<Taker<FilterIndexSequence>> taker) + : FilterKernel(type), taker_(std::move(taker)) {} Status Filter(FunctionContext* ctx, const Array& values, const BooleanArray& filter, int64_t length, std::shared_ptr<Array>* out) override { - const auto& struct_array = checked_cast<const StructArray&>(values); - - TypedBufferBuilder<bool> null_bitmap_builder(ctx->memory_pool()); - RETURN_NOT_OK(null_bitmap_builder.Resize(length)); - - ArrayVector fields(type_->num_children()); - for (int i = 0; i < type_->num_children(); ++i) { - RETURN_NOT_OK(child_kernels_[i]->Filter(ctx, *struct_array.field(i), filter, length, - &fields[i])); - } - - for (int64_t i = 0; i < filter.length(); ++i) { - if (filter.IsNull(i)) { - null_bitmap_builder.UnsafeAppend(false); - continue; - } - if (!filter.Value(i)) { - continue; - } - if (struct_array.IsNull(i)) { - null_bitmap_builder.UnsafeAppend(false); - continue; - } - null_bitmap_builder.UnsafeAppend(true); + if (values.length() != filter.length()) { + return Status::Invalid("filter and value array must have identical lengths"); } - - auto null_count = null_bitmap_builder.false_count(); - std::shared_ptr<Buffer> null_bitmap; - RETURN_NOT_OK(null_bitmap_builder.Finish(&null_bitmap)); - - out->reset(new StructArray(type_, length, fields, null_bitmap, null_count)); - return Status::OK(); + RETURN_NOT_OK(taker_->Init(ctx->memory_pool())); + RETURN_NOT_OK(taker_->Take(values, FilterIndexSequence(filter, length))); + return taker_->Finish(out); } - private: - std::vector<std::unique_ptr<FilterKernel>> child_kernels_; -}; - -template <> -class FilterImpl<FixedSizeListType> : public FilterKernel { - public: - using FilterKernel::FilterKernel; - - Status Filter(FunctionContext* ctx, const Array& values, const BooleanArray& filter, - int64_t length, std::shared_ptr<Array>* out) override { - const auto& list_array = checked_cast<const FixedSizeListArray&>(values); - - TypedBufferBuilder<bool> null_bitmap_builder(ctx->memory_pool()); - RETURN_NOT_OK(null_bitmap_builder.Resize(length)); - - BooleanBuilder value_filter_builder(ctx->memory_pool()); - auto list_size = list_array.list_type()->list_size(); - RETURN_NOT_OK(value_filter_builder.Resize(list_size * length)); - - for (int64_t i = 0; i < filter.length(); ++i) { - if (filter.IsNull(i)) { - null_bitmap_builder.UnsafeAppend(false); - for (int64_t j = 0; j < list_size; ++j) { - value_filter_builder.UnsafeAppendNull(); - } - continue; - } - if (!filter.Value(i)) { - for (int64_t j = 0; j < list_size; ++j) { - value_filter_builder.UnsafeAppend(false); - } - continue; - } - if (values.IsNull(i)) { - null_bitmap_builder.UnsafeAppend(false); - for (int64_t j = 0; j < list_size; ++j) { - value_filter_builder.UnsafeAppendNull(); - } - continue; - } - for (int64_t j = 0; j < list_size; ++j) { - value_filter_builder.UnsafeAppend(true); - } - null_bitmap_builder.UnsafeAppend(true); - } - - std::shared_ptr<BooleanArray> value_filter; - RETURN_NOT_OK(value_filter_builder.Finish(&value_filter)); - std::shared_ptr<Array> out_values; - RETURN_NOT_OK( - arrow::compute::Filter(ctx, *list_array.values(), *value_filter, &out_values)); - - auto null_count = null_bitmap_builder.false_count(); - std::shared_ptr<Buffer> null_bitmap; - RETURN_NOT_OK(null_bitmap_builder.Finish(&null_bitmap)); - - out->reset( - new FixedSizeListArray(type_, length, out_values, null_bitmap, null_count)); - return Status::OK(); - } -}; - -template <> -class FilterImpl<ListType> : public FilterKernel { - public: - using FilterKernel::FilterKernel; - - Status Filter(FunctionContext* ctx, const Array& values, const BooleanArray& filter, - int64_t length, std::shared_ptr<Array>* out) override { - const auto& list_array = checked_cast<const ListArray&>(values); - - TypedBufferBuilder<bool> null_bitmap_builder(ctx->memory_pool()); - RETURN_NOT_OK(null_bitmap_builder.Resize(length)); - - BooleanBuilder value_filter_builder(ctx->memory_pool()); - - TypedBufferBuilder<int32_t> offset_builder(ctx->memory_pool()); - RETURN_NOT_OK(offset_builder.Resize(length + 1)); - int32_t offset = 0; - offset_builder.UnsafeAppend(offset); - - for (int64_t i = 0; i < filter.length(); ++i) { - if (filter.IsNull(i)) { - null_bitmap_builder.UnsafeAppend(false); - offset_builder.UnsafeAppend(offset); - RETURN_NOT_OK( - value_filter_builder.AppendValues(list_array.value_length(i), false)); - continue; - } - if (!filter.Value(i)) { - RETURN_NOT_OK( - value_filter_builder.AppendValues(list_array.value_length(i), false)); - continue; - } - if (values.IsNull(i)) { - null_bitmap_builder.UnsafeAppend(false); - offset_builder.UnsafeAppend(offset); - RETURN_NOT_OK( - value_filter_builder.AppendValues(list_array.value_length(i), false)); - continue; - } - null_bitmap_builder.UnsafeAppend(true); - offset += list_array.value_length(i); - offset_builder.UnsafeAppend(offset); - RETURN_NOT_OK(value_filter_builder.AppendValues(list_array.value_length(i), true)); - } - - std::shared_ptr<BooleanArray> value_filter; - RETURN_NOT_OK(value_filter_builder.Finish(&value_filter)); - std::shared_ptr<Array> out_values; - RETURN_NOT_OK( - arrow::compute::Filter(ctx, *list_array.values(), *value_filter, &out_values)); - - auto null_count = null_bitmap_builder.false_count(); - std::shared_ptr<Buffer> offsets, null_bitmap; - RETURN_NOT_OK(offset_builder.Finish(&offsets)); - RETURN_NOT_OK(null_bitmap_builder.Finish(&null_bitmap)); - - *out = MakeArray(ArrayData::Make(type_, length, {null_bitmap, offsets}, - {out_values->data()}, null_count)); - return Status::OK(); - } -}; - -template <> -class FilterImpl<MapType> : public FilterImpl<ListType> { - using FilterImpl<ListType>::FilterImpl; -}; - -template <> -class FilterImpl<DictionaryType> : public FilterKernel { - public: - FilterImpl(const std::shared_ptr<DataType>& type, std::unique_ptr<FilterKernel> impl) - : FilterKernel(type), impl_(std::move(impl)) {} - - Status Filter(FunctionContext* ctx, const Array& values, const BooleanArray& filter, - int64_t length, std::shared_ptr<Array>* out) override { - auto dict_array = checked_cast<const DictionaryArray*>(&values); - // To filter a dictionary, apply the current kernel to the dictionary's indices. - std::shared_ptr<Array> taken_indices; - RETURN_NOT_OK( - impl_->Filter(ctx, *dict_array->indices(), filter, length, &taken_indices)); - return DictionaryArray::FromArrays(values.type(), taken_indices, - dict_array->dictionary(), out); - } - - private: - std::unique_ptr<FilterKernel> impl_; -}; - -template <> -class FilterImpl<ExtensionType> : public FilterKernel { - public: - FilterImpl(const std::shared_ptr<DataType>& type, std::unique_ptr<FilterKernel> impl) - : FilterKernel(type), impl_(std::move(impl)) {} - - Status Filter(FunctionContext* ctx, const Array& values, const BooleanArray& filter, - int64_t length, std::shared_ptr<Array>* out) override { - auto ext_array = checked_cast<const ExtensionArray*>(&values); - // To take from an extension array, apply the current kernel to storage. - std::shared_ptr<Array> taken_storage; - RETURN_NOT_OK( - impl_->Filter(ctx, *ext_array->storage(), filter, length, &taken_storage)); - *out = ext_array->extension_type()->MakeArray(taken_storage->data()); - return Status::OK(); - } - - private: - std::unique_ptr<FilterKernel> impl_; + std::unique_ptr<Taker<FilterIndexSequence>> taker_; }; Status FilterKernel::Make(const std::shared_ptr<DataType>& value_type, std::unique_ptr<FilterKernel>* out) { - switch (value_type->id()) { -#define NO_CHILD_CASE(T) \ - case T##Type::type_id: \ - *out = internal::make_unique<FilterImpl<T##Type>>(value_type); \ - return Status::OK() - -#define SINGLE_CHILD_CASE(T, CHILD_TYPE) \ - case T##Type::type_id: { \ - auto t = checked_pointer_cast<T##Type>(value_type); \ - std::unique_ptr<FilterKernel> child_filter_impl; \ - RETURN_NOT_OK(FilterKernel::Make(t->CHILD_TYPE(), &child_filter_impl)); \ - *out = internal::make_unique<FilterImpl<T##Type>>(t, std::move(child_filter_impl)); \ - return Status::OK(); \ - } - - NO_CHILD_CASE(Null); - NO_CHILD_CASE(Boolean); - NO_CHILD_CASE(Int8); - NO_CHILD_CASE(Int16); - NO_CHILD_CASE(Int32); - NO_CHILD_CASE(Int64); - NO_CHILD_CASE(UInt8); - NO_CHILD_CASE(UInt16); - NO_CHILD_CASE(UInt32); - NO_CHILD_CASE(UInt64); - NO_CHILD_CASE(Date32); - NO_CHILD_CASE(Date64); - NO_CHILD_CASE(Time32); - NO_CHILD_CASE(Time64); - NO_CHILD_CASE(Timestamp); - NO_CHILD_CASE(Duration); - NO_CHILD_CASE(HalfFloat); - NO_CHILD_CASE(Float); - NO_CHILD_CASE(Double); - NO_CHILD_CASE(String); - NO_CHILD_CASE(Binary); - NO_CHILD_CASE(FixedSizeBinary); - NO_CHILD_CASE(Decimal128); - - SINGLE_CHILD_CASE(Dictionary, index_type); - SINGLE_CHILD_CASE(Extension, storage_type); + std::unique_ptr<Taker<FilterIndexSequence>> taker; + RETURN_NOT_OK(Taker<FilterIndexSequence>::Make(value_type, &taker)); - NO_CHILD_CASE(List); - NO_CHILD_CASE(FixedSizeList); - NO_CHILD_CASE(Map); - - case Type::STRUCT: { - std::vector<std::unique_ptr<FilterKernel>> child_kernels; - for (auto child : value_type->children()) { - child_kernels.emplace_back(); - RETURN_NOT_OK(FilterKernel::Make(child->type(), &child_kernels.back())); - } - *out = internal::make_unique<FilterImpl<StructType>>(value_type, - std::move(child_kernels)); - return Status::OK(); - } - -#undef NO_CHILD_CASE -#undef SINGLE_CHILD_CASE - - default: - return Status::NotImplemented("gathering values of type ", *value_type); - } + out->reset(new FilterKernelImpl(value_type, std::move(taker))); + return Status::OK(); } Status FilterKernel::Call(FunctionContext* ctx, const Datum& values, const Datum& filter, @@ -436,26 +112,26 @@ Status FilterKernel::Call(FunctionContext* ctx, const Datum& values, const Datum } auto values_array = values.make_array(); auto filter_array = checked_pointer_cast<BooleanArray>(filter.make_array()); - const auto length = OutputSize(*filter_array); std::shared_ptr<Array> out_array; - RETURN_NOT_OK(this->Filter(ctx, *values_array, *filter_array, length, &out_array)); + RETURN_NOT_OK(this->Filter(ctx, *values_array, *filter_array, OutputSize(*filter_array), + &out_array)); *out = out_array; return Status::OK(); } -Status Filter(FunctionContext* context, const Array& values, const Array& filter, +Status Filter(FunctionContext* ctx, const Array& values, const Array& filter, std::shared_ptr<Array>* out) { Datum out_datum; - RETURN_NOT_OK(Filter(context, Datum(values.data()), Datum(filter.data()), &out_datum)); + RETURN_NOT_OK(Filter(ctx, Datum(values.data()), Datum(filter.data()), &out_datum)); *out = out_datum.make_array(); return Status::OK(); } -Status Filter(FunctionContext* context, const Datum& values, const Datum& filter, +Status Filter(FunctionContext* ctx, const Datum& values, const Datum& filter, Datum* out) { std::unique_ptr<FilterKernel> kernel; RETURN_NOT_OK(FilterKernel::Make(values.type(), &kernel)); - return kernel->Call(context, values, filter, out); + return kernel->Call(ctx, values, filter, out); } } // namespace compute diff --git a/cpp/src/arrow/compute/kernels/filter.h b/cpp/src/arrow/compute/kernels/filter.h index 46ad3d4..401daa8 100644 --- a/cpp/src/arrow/compute/kernels/filter.h +++ b/cpp/src/arrow/compute/kernels/filter.h @@ -41,23 +41,22 @@ class FunctionContext; /// filter = [0, 1, 1, 0, null, 1], the output will be /// = ["b", "c", null, "f"] /// -/// \param[in] context the FunctionContext +/// \param[in] ctx the FunctionContext /// \param[in] values array to filter /// \param[in] filter indicates which values should be filtered out /// \param[out] out resulting array ARROW_EXPORT -Status Filter(FunctionContext* context, const Array& values, const Array& filter, +Status Filter(FunctionContext* ctx, const Array& values, const Array& filter, std::shared_ptr<Array>* out); /// \brief Filter an array with a boolean selection filter /// -/// \param[in] context the FunctionContext +/// \param[in] ctx the FunctionContext /// \param[in] values datum to filter /// \param[in] filter indicates which values should be filtered out /// \param[out] out resulting datum ARROW_EXPORT -Status Filter(FunctionContext* context, const Datum& values, const Datum& filter, - Datum* out); +Status Filter(FunctionContext* ctx, const Datum& values, const Datum& filter, Datum* out); /// \brief BinaryKernel implementing Filter operation class ARROW_EXPORT FilterKernel : public BinaryKernel { diff --git a/cpp/src/arrow/compute/kernels/take-benchmark.cc b/cpp/src/arrow/compute/kernels/take-benchmark.cc new file mode 100644 index 0000000..139e183 --- /dev/null +++ b/cpp/src/arrow/compute/kernels/take-benchmark.cc @@ -0,0 +1,147 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +#include "benchmark/benchmark.h" + +#include "arrow/compute/kernels/take.h" + +#include "arrow/compute/benchmark-util.h" +#include "arrow/compute/test-util.h" +#include "arrow/testing/gtest_util.h" +#include "arrow/testing/random.h" + +namespace arrow { +namespace compute { + +constexpr auto kSeed = 0x0ff1ce; + +static void TakeBenchmark(benchmark::State& state, const std::shared_ptr<Array>& values, + const std::shared_ptr<Array>& indices) { + FunctionContext ctx; + TakeOptions options; + for (auto _ : state) { + Datum out; + ABORT_NOT_OK(Take(&ctx, Datum(values), Datum(indices), options, &out)); + benchmark::DoNotOptimize(out); + } +} + +static void TakeInt64(benchmark::State& state) { + RegressionArgs args(state); + + const int64_t array_size = args.size / sizeof(int64_t); + auto rand = random::RandomArrayGenerator(kSeed); + + auto values = rand.Int64(array_size, -100, 100, args.null_proportion); + + auto indices = rand.Int32(array_size, 0, array_size - 1, args.null_proportion); + + TakeBenchmark(state, values, indices); +} + +static void TakeFixedSizeList1Int64(benchmark::State& state) { + RegressionArgs args(state); + + const int64_t array_size = args.size / sizeof(int64_t); + auto rand = random::RandomArrayGenerator(kSeed); + + auto int_array = rand.Int64(array_size, -100, 100, args.null_proportion); + auto values = std::make_shared<FixedSizeListArray>( + fixed_size_list(int64(), 1), array_size, int_array, int_array->null_bitmap(), + int_array->null_count()); + + auto indices = rand.Int32(array_size, 0, array_size - 1, args.null_proportion); + + TakeBenchmark(state, values, indices); +} + +static void TakeInt64VsFilter(benchmark::State& state) { + RegressionArgs args(state); + + const int64_t array_size = args.size / sizeof(int64_t); + auto rand = random::RandomArrayGenerator(kSeed); + + auto values = rand.Int64(array_size, -100, 100, args.null_proportion); + + auto filter = std::static_pointer_cast<BooleanArray>( + rand.Boolean(array_size, 0.75, args.null_proportion)); + + Int32Builder indices_builder; + ABORT_NOT_OK(indices_builder.Resize(array_size)); + + for (int64_t i = 0; i < array_size; ++i) { + if (filter->IsNull(i)) { + indices_builder.UnsafeAppendNull(); + } else if (filter->Value(i)) { + indices_builder.UnsafeAppend(static_cast<int32_t>(i)); + } + } + + std::shared_ptr<Array> indices; + ABORT_NOT_OK(indices_builder.Finish(&indices)); + TakeBenchmark(state, values, indices); +} + +static void TakeString(benchmark::State& state) { + RegressionArgs args(state); + + int32_t string_min_length = 0, string_max_length = 128; + int32_t string_mean_length = (string_max_length + string_min_length) / 2; + // for an array of 50% null strings, we need to generate twice as many strings + // to ensure that they have an average of args.size total characters + auto array_size = + static_cast<int64_t>(args.size / string_mean_length / (1 - args.null_proportion)); + + auto rand = random::RandomArrayGenerator(kSeed); + auto values = std::static_pointer_cast<StringArray>(rand.String( + array_size, string_min_length, string_max_length, args.null_proportion)); + + auto indices = rand.Int32(array_size, 0, array_size - 1, args.null_proportion); + + TakeBenchmark(state, values, indices); +} + +BENCHMARK(TakeInt64) + ->Apply(RegressionSetArgs) + ->Args({1 << 20, 1}) + ->Args({1 << 23, 1}) + ->MinTime(1.0) + ->Unit(benchmark::TimeUnit::kNanosecond); + +BENCHMARK(TakeFixedSizeList1Int64) + ->Apply(RegressionSetArgs) + ->Args({1 << 20, 1}) + ->Args({1 << 23, 1}) + ->MinTime(1.0) + ->Unit(benchmark::TimeUnit::kNanosecond); + +BENCHMARK(TakeInt64VsFilter) + ->Apply(RegressionSetArgs) + ->Args({1 << 20, 1}) + ->Args({1 << 23, 1}) + ->MinTime(1.0) + ->Unit(benchmark::TimeUnit::kNanosecond); + +BENCHMARK(TakeString) + ->Apply(RegressionSetArgs) + ->Args({1 << 20, 1}) + ->Args({1 << 23, 1}) + ->MinTime(1.0) + ->Unit(benchmark::TimeUnit::kNanosecond); + +} // namespace compute +} // namespace arrow diff --git a/cpp/src/arrow/compute/kernels/take-internal.h b/cpp/src/arrow/compute/kernels/take-internal.h new file mode 100644 index 0000000..bacd71b --- /dev/null +++ b/cpp/src/arrow/compute/kernels/take-internal.h @@ -0,0 +1,553 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +#pragma once + +#include <algorithm> +#include <limits> +#include <memory> +#include <utility> +#include <vector> + +#include "arrow/builder.h" +#include "arrow/compute/context.h" +#include "arrow/util/bit-util.h" +#include "arrow/util/checked_cast.h" +#include "arrow/util/logging.h" +#include "arrow/util/stl.h" +#include "arrow/visitor_inline.h" + +namespace arrow { +namespace compute { + +using internal::checked_cast; +using internal::checked_pointer_cast; + +template <typename Builder, typename Scalar> +static Status UnsafeAppend(Builder* builder, Scalar&& value) { + builder->UnsafeAppend(std::forward<Scalar>(value)); + return Status::OK(); +} + +// Use BinaryBuilder::UnsafeAppend, but reserve byte storage first +static Status UnsafeAppend(BinaryBuilder* builder, util::string_view value) { + RETURN_NOT_OK(builder->ReserveData(static_cast<int64_t>(value.size()))); + builder->UnsafeAppend(value); + return Status::OK(); +} + +// Use StringBuilder::UnsafeAppend, but reserve character storage first +static Status UnsafeAppend(StringBuilder* builder, util::string_view value) { + RETURN_NOT_OK(builder->ReserveData(static_cast<int64_t>(value.size()))); + builder->UnsafeAppend(value); + return Status::OK(); +} + +/// \brief visit indices from an IndexSequence while bounds checking +/// +/// \param[in] indices IndexSequence to visit +/// \param[in] values array to bounds check against, if necessary +/// \param[in] vis index visitor, signature must be Status(int64_t index, bool is_valid) +template <bool SomeIndicesNull, bool SomeValuesNull, bool NeverOutOfBounds, + typename IndexSequence, typename Visitor> +Status VisitIndices(IndexSequence indices, const Array& values, Visitor&& vis) { + for (int64_t i = 0; i < indices.length(); ++i) { + auto index_valid = indices.Next(); + if (SomeIndicesNull && !index_valid.second) { + RETURN_NOT_OK(vis(0, false)); + continue; + } + + auto index = index_valid.first; + if (!NeverOutOfBounds) { + if (index < 0 || index >= values.length()) { + return Status::IndexError("take index out of bounds"); + } + } + + bool is_valid = !SomeValuesNull || values.IsValid(index); + RETURN_NOT_OK(vis(index, is_valid)); + } + return Status::OK(); +} + +template <bool SomeIndicesNull, bool SomeValuesNull, typename IndexSequence, + typename Visitor> +Status VisitIndices(IndexSequence indices, const Array& values, Visitor&& vis) { + if (indices.never_out_of_bounds()) { + return VisitIndices<SomeIndicesNull, SomeValuesNull, true>( + indices, values, std::forward<Visitor>(vis)); + } + return VisitIndices<SomeIndicesNull, SomeValuesNull, false>(indices, values, + std::forward<Visitor>(vis)); +} + +template <bool SomeIndicesNull, typename IndexSequence, typename Visitor> +Status VisitIndices(IndexSequence indices, const Array& values, Visitor&& vis) { + if (values.null_count() == 0) { + return VisitIndices<SomeIndicesNull, false>(indices, values, + std::forward<Visitor>(vis)); + } + return VisitIndices<SomeIndicesNull, true>(indices, values, std::forward<Visitor>(vis)); +} + +template <typename IndexSequence, typename Visitor> +Status VisitIndices(IndexSequence indices, const Array& values, Visitor&& vis) { + if (indices.null_count() == 0) { + return VisitIndices<false>(indices, values, std::forward<Visitor>(vis)); + } + return VisitIndices<true>(indices, values, std::forward<Visitor>(vis)); +} + +// Helper class for gathering values from an array +template <typename IndexSequence> +class Taker { + public: + explicit Taker(const std::shared_ptr<DataType>& type) : type_(type) {} + + virtual ~Taker() = default; + + // construct any children, must be called once after construction + virtual Status MakeChildren() { return Status::OK(); } + + // reset this Taker, prepare to gather into an array allocated from pool + // must be called each time the output pool may have changed + virtual Status Init(MemoryPool* pool) = 0; + + // gather elements from an array at the provided indices + virtual Status Take(const Array& values, IndexSequence indices) = 0; + + // assemble an array of all gathered values + virtual Status Finish(std::shared_ptr<Array>*) = 0; + + // factory; the output Taker will support gathering values of the given type + static Status Make(const std::shared_ptr<DataType>& type, std::unique_ptr<Taker>* out); + + static_assert(std::is_literal_type<IndexSequence>::value, + "Index sequences must be literal type"); + + static_assert(std::is_copy_constructible<IndexSequence>::value, + "Index sequences must be copy constructible"); + + static_assert(std::is_same<decltype(std::declval<IndexSequence>().Next()), + std::pair<int64_t, bool>>::value, + "An index sequence must yield pairs of indices:int64_t, validity:bool."); + + static_assert(std::is_same<decltype(std::declval<const IndexSequence>().length()), + int64_t>::value, + "An index sequence must provide its length."); + + static_assert(std::is_same<decltype(std::declval<const IndexSequence>().null_count()), + int64_t>::value, + "An index sequence must provide the number of nulls it will take."); + + static_assert( + std::is_same<decltype(std::declval<const IndexSequence>().never_out_of_bounds()), + bool>::value, + "Index sequences must declare whether bounds checking is necessary"); + + static_assert( + std::is_same<decltype(std::declval<IndexSequence>().set_never_out_of_bounds()), + void>::value, + "An index sequence must support ignoring bounds checking."); + + protected: + template <typename Builder> + Status MakeBuilder(MemoryPool* pool, std::unique_ptr<Builder>* out) { + std::unique_ptr<ArrayBuilder> builder; + RETURN_NOT_OK(arrow::MakeBuilder(pool, type_, &builder)); + out->reset(checked_cast<Builder*>(builder.release())); + return Status::OK(); + } + + std::shared_ptr<DataType> type_; +}; + +// an IndexSequence which yields indices from a specified range +// or yields null for the length of that range +class RangeIndexSequence { + public: + constexpr bool never_out_of_bounds() const { return true; } + void set_never_out_of_bounds() {} + + constexpr RangeIndexSequence() = default; + + RangeIndexSequence(bool is_valid, int64_t offset, int64_t length) + : is_valid_(is_valid), index_(offset), length_(length) {} + + std::pair<int64_t, bool> Next() { return std::make_pair(index_++, is_valid_); } + + int64_t length() const { return length_; } + + int64_t null_count() const { return is_valid_ ? 0 : length_; } + + private: + bool is_valid_ = true; + int64_t index_ = 0, length_ = -1; +}; + +// Default implementation: taking from a simple array into a builder requires only that +// the array supports array.GetView() and the corresponding builder supports +// builder.UnsafeAppend(array.GetView()) +template <typename IndexSequence, typename T> +class TakerImpl : public Taker<IndexSequence> { + public: + using ArrayType = typename TypeTraits<T>::ArrayType; + using BuilderType = typename TypeTraits<T>::BuilderType; + + using Taker<IndexSequence>::Taker; + + Status Init(MemoryPool* pool) override { return this->MakeBuilder(pool, &builder_); } + + Status Take(const Array& values, IndexSequence indices) override { + DCHECK(this->type_->Equals(values.type())); + RETURN_NOT_OK(builder_->Reserve(indices.length())); + return VisitIndices(indices, values, [&](int64_t index, bool is_valid) { + if (!is_valid) { + builder_->UnsafeAppendNull(); + return Status::OK(); + } + auto value = checked_cast<const ArrayType&>(values).GetView(index); + return UnsafeAppend(builder_.get(), value); + }); + } + + Status Finish(std::shared_ptr<Array>* out) override { return builder_->Finish(out); } + + private: + std::unique_ptr<BuilderType> builder_; +}; + +// Gathering from NullArrays is trivial; skip the builder and just +// do bounds checking +template <typename IndexSequence> +class TakerImpl<IndexSequence, NullType> : public Taker<IndexSequence> { + public: + using Taker<IndexSequence>::Taker; + + Status Init(MemoryPool*) override { return Status::OK(); } + + Status Take(const Array& values, IndexSequence indices) override { + DCHECK(this->type_->Equals(values.type())); + + length_ += indices.length(); + + if (indices.never_out_of_bounds()) { + return Status::OK(); + } + + return VisitIndices(indices, values, [](int64_t, bool) { return Status::OK(); }); + } + + Status Finish(std::shared_ptr<Array>* out) override { + out->reset(new NullArray(length_)); + return Status::OK(); + } + + private: + int64_t length_ = 0; +}; + +template <typename IndexSequence> +class TakerImpl<IndexSequence, ListType> : public Taker<IndexSequence> { + public: + using Taker<IndexSequence>::Taker; + + Status MakeChildren() override { + const auto& list_type = checked_cast<const ListType&>(*this->type_); + return Taker<RangeIndexSequence>::Make(list_type.value_type(), &value_taker_); + } + + Status Init(MemoryPool* pool) override { + null_bitmap_builder_.reset(new TypedBufferBuilder<bool>(pool)); + offset_builder_.reset(new TypedBufferBuilder<int32_t>(pool)); + RETURN_NOT_OK(offset_builder_->Append(0)); + return value_taker_->Init(pool); + } + + Status Take(const Array& values, IndexSequence indices) override { + DCHECK(this->type_->Equals(values.type())); + + const auto& list_array = checked_cast<const ListArray&>(values); + + RETURN_NOT_OK(null_bitmap_builder_->Reserve(indices.length())); + RETURN_NOT_OK(offset_builder_->Reserve(indices.length())); + + int32_t offset = offset_builder_->data()[offset_builder_->length() - 1]; + return VisitIndices(indices, values, [&](int64_t index, bool is_valid) { + null_bitmap_builder_->UnsafeAppend(is_valid); + + if (is_valid) { + offset += list_array.value_length(index); + RangeIndexSequence value_indices(true, list_array.value_offset(index), + list_array.value_length(index)); + RETURN_NOT_OK(value_taker_->Take(*list_array.values(), value_indices)); + } + + offset_builder_->UnsafeAppend(offset); + return Status::OK(); + }); + } + + Status Finish(std::shared_ptr<Array>* out) override { return FinishAs<ListArray>(out); } + + protected: + // this added method is provided for use by TakerImpl<IndexSequence, MapType>, + // which needs to construct a MapArray rather than a ListArray + template <typename T> + Status FinishAs(std::shared_ptr<Array>* out) { + auto null_count = null_bitmap_builder_->false_count(); + auto length = null_bitmap_builder_->length(); + + std::shared_ptr<Buffer> offsets, null_bitmap; + RETURN_NOT_OK(null_bitmap_builder_->Finish(&null_bitmap)); + RETURN_NOT_OK(offset_builder_->Finish(&offsets)); + + std::shared_ptr<Array> taken_values; + RETURN_NOT_OK(value_taker_->Finish(&taken_values)); + + out->reset( + new T(this->type_, length, offsets, taken_values, null_bitmap, null_count)); + return Status::OK(); + } + + std::unique_ptr<TypedBufferBuilder<bool>> null_bitmap_builder_; + std::unique_ptr<TypedBufferBuilder<int32_t>> offset_builder_; + std::unique_ptr<Taker<RangeIndexSequence>> value_taker_; +}; + +template <typename IndexSequence> +class TakerImpl<IndexSequence, MapType> : public TakerImpl<IndexSequence, ListType> { + public: + using TakerImpl<IndexSequence, ListType>::TakerImpl; + + Status Finish(std::shared_ptr<Array>* out) override { + return this->template FinishAs<MapArray>(out); + } +}; + +template <typename IndexSequence> +class TakerImpl<IndexSequence, FixedSizeListType> : public Taker<IndexSequence> { + public: + using Taker<IndexSequence>::Taker; + + Status MakeChildren() override { + const auto& list_type = checked_cast<const FixedSizeListType&>(*this->type_); + return Taker<RangeIndexSequence>::Make(list_type.value_type(), &value_taker_); + } + + Status Init(MemoryPool* pool) override { + null_bitmap_builder_.reset(new TypedBufferBuilder<bool>(pool)); + return value_taker_->Init(pool); + } + + Status Take(const Array& values, IndexSequence indices) override { + DCHECK(this->type_->Equals(values.type())); + + const auto& list_array = checked_cast<const FixedSizeListArray&>(values); + auto list_size = list_array.list_type()->list_size(); + + RETURN_NOT_OK(null_bitmap_builder_->Reserve(indices.length())); + return VisitIndices(indices, values, [&](int64_t index, bool is_valid) { + null_bitmap_builder_->UnsafeAppend(is_valid); + + // for FixedSizeList, null lists are not empty (they also span a segment of + // list_size in the child data), so we must append to value_taker_ even if !is_valid + RangeIndexSequence value_indices(is_valid, list_array.value_offset(index), + list_size); + return value_taker_->Take(*list_array.values(), value_indices); + }); + } + + Status Finish(std::shared_ptr<Array>* out) override { + auto null_count = null_bitmap_builder_->false_count(); + auto length = null_bitmap_builder_->length(); + + std::shared_ptr<Buffer> null_bitmap; + RETURN_NOT_OK(null_bitmap_builder_->Finish(&null_bitmap)); + + std::shared_ptr<Array> taken_values; + RETURN_NOT_OK(value_taker_->Finish(&taken_values)); + + out->reset(new FixedSizeListArray(this->type_, length, taken_values, null_bitmap, + null_count)); + return Status::OK(); + } + + protected: + std::unique_ptr<TypedBufferBuilder<bool>> null_bitmap_builder_; + std::unique_ptr<Taker<RangeIndexSequence>> value_taker_; +}; + +template <typename IndexSequence> +class TakerImpl<IndexSequence, StructType> : public Taker<IndexSequence> { + public: + using Taker<IndexSequence>::Taker; + + Status MakeChildren() override { + children_.resize(this->type_->num_children()); + for (int i = 0; i < this->type_->num_children(); ++i) { + RETURN_NOT_OK( + Taker<IndexSequence>::Make(this->type_->child(i)->type(), &children_[i])); + } + return Status::OK(); + } + + Status Init(MemoryPool* pool) override { + null_bitmap_builder_.reset(new TypedBufferBuilder<bool>(pool)); + for (int i = 0; i < this->type_->num_children(); ++i) { + RETURN_NOT_OK(children_[i]->Init(pool)); + } + return Status::OK(); + } + + Status Take(const Array& values, IndexSequence indices) override { + DCHECK(this->type_->Equals(values.type())); + + RETURN_NOT_OK(null_bitmap_builder_->Reserve(indices.length())); + RETURN_NOT_OK(VisitIndices(indices, values, [&](int64_t, bool is_valid) { + null_bitmap_builder_->UnsafeAppend(is_valid); + return Status::OK(); + })); + + // bounds checking was done while appending to the null bitmap + indices.set_never_out_of_bounds(); + + const auto& struct_array = checked_cast<const StructArray&>(values); + for (int i = 0; i < this->type_->num_children(); ++i) { + RETURN_NOT_OK(children_[i]->Take(*struct_array.field(i), indices)); + } + return Status::OK(); + } + + Status Finish(std::shared_ptr<Array>* out) override { + auto null_count = null_bitmap_builder_->false_count(); + auto length = null_bitmap_builder_->length(); + std::shared_ptr<Buffer> null_bitmap; + RETURN_NOT_OK(null_bitmap_builder_->Finish(&null_bitmap)); + + ArrayVector fields(this->type_->num_children()); + for (int i = 0; i < this->type_->num_children(); ++i) { + RETURN_NOT_OK(children_[i]->Finish(&fields[i])); + } + + out->reset( + new StructArray(this->type_, length, std::move(fields), null_bitmap, null_count)); + return Status::OK(); + } + + protected: + std::unique_ptr<TypedBufferBuilder<bool>> null_bitmap_builder_; + std::vector<std::unique_ptr<Taker<IndexSequence>>> children_; +}; + +// taking from a DictionaryArray is accomplished by taking from its indices +template <typename IndexSequence> +class TakerImpl<IndexSequence, DictionaryType> : public Taker<IndexSequence> { + public: + using Taker<IndexSequence>::Taker; + + Status MakeChildren() override { + const auto& dict_type = checked_cast<const DictionaryType&>(*this->type_); + return Taker<IndexSequence>::Make(dict_type.index_type(), &index_taker_); + } + + Status Init(MemoryPool* pool) override { + dictionary_ = nullptr; + return index_taker_->Init(pool); + } + + Status Take(const Array& values, IndexSequence indices) override { + DCHECK(this->type_->Equals(values.type())); + const auto& dict_array = checked_cast<const DictionaryArray&>(values); + + if (dictionary_ != nullptr && dictionary_ != dict_array.dictionary()) { + return Status::NotImplemented( + "taking from DictionaryArrays with different dictionaries"); + } else { + dictionary_ = dict_array.dictionary(); + } + return index_taker_->Take(*dict_array.indices(), indices); + } + + Status Finish(std::shared_ptr<Array>* out) override { + std::shared_ptr<Array> taken_indices; + RETURN_NOT_OK(index_taker_->Finish(&taken_indices)); + out->reset(new DictionaryArray(this->type_, taken_indices, dictionary_)); + return Status::OK(); + } + + protected: + std::shared_ptr<Array> dictionary_; + std::unique_ptr<Taker<IndexSequence>> index_taker_; +}; + +// taking from an ExtensionArray is accomplished by taking from its storage +template <typename IndexSequence> +class TakerImpl<IndexSequence, ExtensionType> : public Taker<IndexSequence> { + public: + using Taker<IndexSequence>::Taker; + + Status MakeChildren() override { + const auto& ext_type = checked_cast<const ExtensionType&>(*this->type_); + return Taker<IndexSequence>::Make(ext_type.storage_type(), &storage_taker_); + } + + Status Init(MemoryPool* pool) override { return storage_taker_->Init(pool); } + + Status Take(const Array& values, IndexSequence indices) override { + DCHECK(this->type_->Equals(values.type())); + const auto& ext_array = checked_cast<const ExtensionArray&>(values); + return storage_taker_->Take(*ext_array.storage(), indices); + } + + Status Finish(std::shared_ptr<Array>* out) override { + std::shared_ptr<Array> taken_storage; + RETURN_NOT_OK(storage_taker_->Finish(&taken_storage)); + out->reset(new ExtensionArray(this->type_, taken_storage)); + return Status::OK(); + } + + protected: + std::unique_ptr<Taker<IndexSequence>> storage_taker_; +}; + +template <typename IndexSequence> +struct TakerMakeImpl { + template <typename T> + Status Visit(const T&) { + out_->reset(new TakerImpl<IndexSequence, T>(type_)); + return (*out_)->MakeChildren(); + } + + Status Visit(const UnionType& t) { + return Status::NotImplemented("gathering values of type ", t); + } + + std::shared_ptr<DataType> type_; + std::unique_ptr<Taker<IndexSequence>>* out_; +}; + +template <typename IndexSequence> +Status Taker<IndexSequence>::Make(const std::shared_ptr<DataType>& type, + std::unique_ptr<Taker>* out) { + TakerMakeImpl<IndexSequence> visitor{type, out}; + return VisitTypeInline(*type, &visitor); +} + +} // namespace compute +} // namespace arrow diff --git a/cpp/src/arrow/compute/kernels/take-test.cc b/cpp/src/arrow/compute/kernels/take-test.cc index c61aeda..da5e0c0 100644 --- a/cpp/src/arrow/compute/kernels/take-test.cc +++ b/cpp/src/arrow/compute/kernels/take-test.cc @@ -29,31 +29,40 @@ namespace arrow { namespace compute { +using internal::checked_cast; +using internal::checked_pointer_cast; using util::string_view; +constexpr auto kSeed = 0x0ff1ce; + template <typename ArrowType> class TestTakeKernel : public ComputeFixture, public TestBase { protected: void AssertTakeArrays(const std::shared_ptr<Array>& values, - const std::shared_ptr<Array>& indices, TakeOptions options, + const std::shared_ptr<Array>& indices, const std::shared_ptr<Array>& expected) { std::shared_ptr<Array> actual; + TakeOptions options; ASSERT_OK(arrow::compute::Take(&this->ctx_, *values, *indices, options, &actual)); + ASSERT_OK(ValidateArray(*actual)); AssertArraysEqual(*expected, *actual); } + void AssertTake(const std::shared_ptr<DataType>& type, const std::string& values, - const std::string& indices, TakeOptions options, - const std::string& expected) { + const std::string& indices, const std::string& expected) { std::shared_ptr<Array> actual; for (auto index_type : {int8(), uint32()}) { - ASSERT_OK(this->Take(type, values, index_type, indices, options, &actual)); + ASSERT_OK(this->Take(type, values, index_type, indices, &actual)); + ASSERT_OK(ValidateArray(*actual)); AssertArraysEqual(*ArrayFromJSON(type, expected), *actual); } } + Status Take(const std::shared_ptr<DataType>& type, const std::string& values, const std::shared_ptr<DataType>& index_type, const std::string& indices, - TakeOptions options, std::shared_ptr<Array>* out) { + std::shared_ptr<Array>* out) { + TakeOptions options; return arrow::compute::Take(&this->ctx_, *ArrayFromJSON(type, values), *ArrayFromJSON(index_type, indices), options, out); } @@ -62,82 +71,123 @@ class TestTakeKernel : public ComputeFixture, public TestBase { class TestTakeKernelWithNull : public TestTakeKernel<NullType> { protected: void AssertTake(const std::string& values, const std::string& indices, - TakeOptions options, const std::string& expected) { - TestTakeKernel<NullType>::AssertTake(utf8(), values, indices, options, expected); + const std::string& expected) { + TestTakeKernel<NullType>::AssertTake(null(), values, indices, expected); } }; TEST_F(TestTakeKernelWithNull, TakeNull) { - TakeOptions options; - this->AssertTake("[null, null, null]", "[0, 1, 0]", options, "[null, null, null]"); + this->AssertTake("[null, null, null]", "[0, 1, 0]", "[null, null, null]"); std::shared_ptr<Array> arr; - ASSERT_RAISES(IndexError, this->Take(null(), "[null, null, null]", int8(), "[0, 9, 0]", - options, &arr)); + ASSERT_RAISES(IndexError, + this->Take(null(), "[null, null, null]", int8(), "[0, 9, 0]", &arr)); + ASSERT_RAISES(IndexError, + this->Take(boolean(), "[null, null, null]", int8(), "[0, -1, 0]", &arr)); } TEST_F(TestTakeKernelWithNull, InvalidIndexType) { - TakeOptions options; std::shared_ptr<Array> arr; ASSERT_RAISES(TypeError, this->Take(null(), "[null, null, null]", float32(), - "[0.0, 1.0, 0.1]", options, &arr)); + "[0.0, 1.0, 0.1]", &arr)); } class TestTakeKernelWithBoolean : public TestTakeKernel<BooleanType> { protected: void AssertTake(const std::string& values, const std::string& indices, - TakeOptions options, const std::string& expected) { - TestTakeKernel<BooleanType>::AssertTake(boolean(), values, indices, options, - expected); + const std::string& expected) { + TestTakeKernel<BooleanType>::AssertTake(boolean(), values, indices, expected); } }; TEST_F(TestTakeKernelWithBoolean, TakeBoolean) { - TakeOptions options; - this->AssertTake("[true, false, true]", "[0, 1, 0]", options, "[true, false, true]"); - this->AssertTake("[null, false, true]", "[0, 1, 0]", options, "[null, false, null]"); - this->AssertTake("[true, false, true]", "[null, 1, 0]", options, "[null, false, true]"); + this->AssertTake("[7, 8, 9]", "[]", "[]"); + this->AssertTake("[true, false, true]", "[0, 1, 0]", "[true, false, true]"); + this->AssertTake("[null, false, true]", "[0, 1, 0]", "[null, false, null]"); + this->AssertTake("[true, false, true]", "[null, 1, 0]", "[null, false, true]"); std::shared_ptr<Array> arr; - ASSERT_RAISES(IndexError, this->Take(boolean(), "[true, false, true]", int8(), - "[0, 9, 0]", options, &arr)); + ASSERT_RAISES(IndexError, + this->Take(boolean(), "[true, false, true]", int8(), "[0, 9, 0]", &arr)); + ASSERT_RAISES(IndexError, + this->Take(boolean(), "[true, false, true]", int8(), "[0, -1, 0]", &arr)); } template <typename ArrowType> class TestTakeKernelWithNumeric : public TestTakeKernel<ArrowType> { protected: void AssertTake(const std::string& values, const std::string& indices, - TakeOptions options, const std::string& expected) { - TestTakeKernel<ArrowType>::AssertTake(type_singleton(), values, indices, options, - expected); + const std::string& expected) { + TestTakeKernel<ArrowType>::AssertTake(type_singleton(), values, indices, expected); } + std::shared_ptr<DataType> type_singleton() { return TypeTraits<ArrowType>::type_singleton(); } + + void ValidateTake(const std::shared_ptr<Array>& values, + const std::shared_ptr<Array>& indices_boxed) { + std::shared_ptr<Array> taken; + TakeOptions options; + ASSERT_OK( + arrow::compute::Take(&this->ctx_, *values, *indices_boxed, options, &taken)); + ASSERT_OK(ValidateArray(*taken)); + ASSERT_EQ(indices_boxed->length(), taken->length()); + + ASSERT_EQ(indices_boxed->type_id(), Type::INT32); + auto indices = checked_pointer_cast<Int32Array>(indices_boxed); + for (int64_t i = 0; i < indices->length(); ++i) { + if (indices->IsNull(i)) { + ASSERT_TRUE(taken->IsNull(i)); + continue; + } + int32_t taken_index = indices->Value(i); + ASSERT_TRUE(values->RangeEquals(taken_index, taken_index + 1, i, taken)); + } + } }; TYPED_TEST_CASE(TestTakeKernelWithNumeric, NumericArrowTypes); TYPED_TEST(TestTakeKernelWithNumeric, TakeNumeric) { - TakeOptions options; - this->AssertTake("[7, 8, 9]", "[0, 1, 0]", options, "[7, 8, 7]"); - this->AssertTake("[null, 8, 9]", "[0, 1, 0]", options, "[null, 8, null]"); - this->AssertTake("[7, 8, 9]", "[null, 1, 0]", options, "[null, 8, 7]"); - this->AssertTake("[null, 8, 9]", "[]", options, "[]"); + this->AssertTake("[7, 8, 9]", "[]", "[]"); + this->AssertTake("[7, 8, 9]", "[0, 1, 0]", "[7, 8, 7]"); + this->AssertTake("[null, 8, 9]", "[0, 1, 0]", "[null, 8, null]"); + this->AssertTake("[7, 8, 9]", "[null, 1, 0]", "[null, 8, 7]"); + this->AssertTake("[null, 8, 9]", "[]", "[]"); + this->AssertTake("[7, 8, 9]", "[0, 0, 0, 0, 0, 0, 2]", "[7, 7, 7, 7, 7, 7, 9]"); std::shared_ptr<Array> arr; ASSERT_RAISES(IndexError, this->Take(this->type_singleton(), "[7, 8, 9]", int8(), - "[0, 9, 0]", options, &arr)); + "[0, 9, 0]", &arr)); + ASSERT_RAISES(IndexError, this->Take(this->type_singleton(), "[7, 8, 9]", int8(), + "[0, -1, 0]", &arr)); +} + +TYPED_TEST(TestTakeKernelWithNumeric, TakeRandomNumeric) { + auto rand = random::RandomArrayGenerator(kSeed); + for (size_t i = 3; i < 8; i++) { + const int64_t length = static_cast<int64_t>(1ULL << i); + for (size_t j = 0; j < 13; j++) { + const int64_t indices_length = static_cast<int64_t>(1ULL << j); + for (auto null_probability : {0.0, 0.01, 0.25, 1.0}) { + auto values = rand.Numeric<TypeParam>(length, 0, 127, null_probability); + auto max_index = static_cast<int32_t>(length - 1); + auto filter = rand.Int32(indices_length, 0, max_index, null_probability); + this->ValidateTake(values, filter); + } + } + } } class TestTakeKernelWithString : public TestTakeKernel<StringType> { protected: void AssertTake(const std::string& values, const std::string& indices, - TakeOptions options, const std::string& expected) { - TestTakeKernel<StringType>::AssertTake(utf8(), values, indices, options, expected); + const std::string& expected) { + TestTakeKernel<StringType>::AssertTake(utf8(), values, indices, expected); } void AssertTakeDictionary(const std::string& dictionary_values, const std::string& dictionary_indices, - const std::string& indices, TakeOptions options, + const std::string& indices, const std::string& expected_indices) { auto dict = ArrayFromJSON(utf8(), dictionary_values); auto type = dictionary(int8(), utf8()); @@ -147,28 +197,272 @@ class TestTakeKernelWithString : public TestTakeKernel<StringType> { ASSERT_OK(DictionaryArray::FromArrays(type, ArrayFromJSON(int8(), expected_indices), dict, &expected)); auto take_indices = ArrayFromJSON(int8(), indices); - this->AssertTakeArrays(values, take_indices, options, expected); + this->AssertTakeArrays(values, take_indices, expected); } }; TEST_F(TestTakeKernelWithString, TakeString) { - TakeOptions options; - this->AssertTake(R"(["a", "b", "c"])", "[0, 1, 0]", options, R"(["a", "b", "a"])"); - this->AssertTake(R"([null, "b", "c"])", "[0, 1, 0]", options, "[null, \"b\", null]"); - this->AssertTake(R"(["a", "b", "c"])", "[null, 1, 0]", options, R"([null, "b", "a"])"); + this->AssertTake(R"(["a", "b", "c"])", "[0, 1, 0]", R"(["a", "b", "a"])"); + this->AssertTake(R"([null, "b", "c"])", "[0, 1, 0]", "[null, \"b\", null]"); + this->AssertTake(R"(["a", "b", "c"])", "[null, 1, 0]", R"([null, "b", "a"])"); std::shared_ptr<Array> arr; - ASSERT_RAISES(IndexError, this->Take(utf8(), R"(["a", "b", "c"])", int8(), "[0, 9, 0]", - options, &arr)); + ASSERT_RAISES(IndexError, + this->Take(utf8(), R"(["a", "b", "c"])", int8(), "[0, 9, 0]", &arr)); + ASSERT_RAISES(IndexError, this->Take(utf8(), R"(["a", "b", null, "ddd", "ee"])", + int64(), "[2, 5]", &arr)); } TEST_F(TestTakeKernelWithString, TakeDictionary) { - TakeOptions options; auto dict = R"(["a", "b", "c", "d", "e"])"; - this->AssertTakeDictionary(dict, "[3, 4, 2]", "[0, 1, 0]", options, "[3, 4, 3]"); - this->AssertTakeDictionary(dict, "[null, 4, 2]", "[0, 1, 0]", options, - "[null, 4, null]"); - this->AssertTakeDictionary(dict, "[3, 4, 2]", "[null, 1, 0]", options, "[null, 4, 3]"); + this->AssertTakeDictionary(dict, "[3, 4, 2]", "[0, 1, 0]", "[3, 4, 3]"); + this->AssertTakeDictionary(dict, "[null, 4, 2]", "[0, 1, 0]", "[null, 4, null]"); + this->AssertTakeDictionary(dict, "[3, 4, 2]", "[null, 1, 0]", "[null, 4, 3]"); +} + +class TestTakeKernelWithList : public TestTakeKernel<ListType> {}; + +TEST_F(TestTakeKernelWithList, TakeListInt32) { + std::string list_json = "[[], [1,2], null, [3]]"; + this->AssertTake(list(int32()), list_json, "[]", "[]"); + this->AssertTake(list(int32()), list_json, "[3, 2, 1]", "[[3], null, [1,2]]"); + this->AssertTake(list(int32()), list_json, "[null, 3, 0]", "[null, [3], []]"); + this->AssertTake(list(int32()), list_json, "[null, null]", "[null, null]"); + this->AssertTake(list(int32()), list_json, "[3, 0, 0, 3]", "[[3], [], [], [3]]"); + this->AssertTake(list(int32()), list_json, "[0, 1, 2, 3]", list_json); + this->AssertTake(list(int32()), list_json, "[0, 0, 0, 0, 0, 0, 1]", + "[[], [], [], [], [], [], [1, 2]]"); +} + +TEST_F(TestTakeKernelWithList, TakeListListInt32) { + std::string list_json = R"([ + [], + [[1], [2, null, 2], []], + null, + [[3, null], null] + ])"; + auto type = list(list(int32())); + this->AssertTake(type, list_json, "[]", "[]"); + this->AssertTake(type, list_json, "[3, 2, 1]", R"([ + [[3, null], null], + null, + [[1], [2, null, 2], []] + ])"); + this->AssertTake(type, list_json, "[null, 3, 0]", R"([ + null, + [[3, null], null], + [] + ])"); + this->AssertTake(type, list_json, "[null, null]", "[null, null]"); + this->AssertTake(type, list_json, "[3, 0, 0, 3]", + "[[[3, null], null], [], [], [[3, null], null]]"); + this->AssertTake(type, list_json, "[0, 1, 2, 3]", list_json); + this->AssertTake(type, list_json, "[0, 0, 0, 0, 0, 0, 1]", + "[[], [], [], [], [], [], [[1], [2, null, 2], []]]"); +} + +class TestTakeKernelWithFixedSizeList : public TestTakeKernel<FixedSizeListType> {}; + +TEST_F(TestTakeKernelWithFixedSizeList, TakeFixedSizeListInt32) { + std::string list_json = "[null, [1, null, 3], [4, 5, 6], [7, 8, null]]"; + this->AssertTake(fixed_size_list(int32(), 3), list_json, "[]", "[]"); + this->AssertTake(fixed_size_list(int32(), 3), list_json, "[3, 2, 1]", + "[[7, 8, null], [4, 5, 6], [1, null, 3]]"); + this->AssertTake(fixed_size_list(int32(), 3), list_json, "[null, 2, 0]", + "[null, [4, 5, 6], null]"); + this->AssertTake(fixed_size_list(int32(), 3), list_json, "[null, null]", + "[null, null]"); + this->AssertTake(fixed_size_list(int32(), 3), list_json, "[3, 0, 0, 3]", + "[[7, 8, null], null, null, [7, 8, null]]"); + this->AssertTake(fixed_size_list(int32(), 3), list_json, "[0, 1, 2, 3]", list_json); + this->AssertTake( + fixed_size_list(int32(), 3), list_json, "[2, 2, 2, 2, 2, 2, 1]", + "[[4, 5, 6], [4, 5, 6], [4, 5, 6], [4, 5, 6], [4, 5, 6], [4, 5, 6], [1, null, 3]]"); +} + +class TestTakeKernelWithMap : public TestTakeKernel<MapType> {}; + +TEST_F(TestTakeKernelWithMap, TakeMapStringToInt32) { + std::string map_json = R"([ + [["joe", 0], ["mark", null]], + null, + [["cap", 8]], + [] + ])"; + this->AssertTake(map(utf8(), int32()), map_json, "[]", "[]"); + this->AssertTake(map(utf8(), int32()), map_json, "[3, 1, 3, 1, 3]", + "[[], null, [], null, []]"); + this->AssertTake(map(utf8(), int32()), map_json, "[2, 1, null]", R"([ + [["cap", 8]], + null, + null + ])"); + this->AssertTake(map(utf8(), int32()), map_json, "[2, 1, 0]", R"([ + [["cap", 8]], + null, + [["joe", 0], ["mark", null]] + ])"); + this->AssertTake(map(utf8(), int32()), map_json, "[0, 1, 2, 3]", map_json); + this->AssertTake(map(utf8(), int32()), map_json, "[0, 0, 0, 0, 0, 0, 3]", R"([ + [["joe", 0], ["mark", null]], + [["joe", 0], ["mark", null]], + [["joe", 0], ["mark", null]], + [["joe", 0], ["mark", null]], + [["joe", 0], ["mark", null]], + [["joe", 0], ["mark", null]], + [] + ])"); +} + +class TestTakeKernelWithStruct : public TestTakeKernel<StructType> {}; + +TEST_F(TestTakeKernelWithStruct, TakeStruct) { + auto struct_type = struct_({field("a", int32()), field("b", utf8())}); + auto struct_json = R"([ + null, + {"a": 1, "b": ""}, + {"a": 2, "b": "hello"}, + {"a": 4, "b": "eh"} + ])"; + this->AssertTake(struct_type, struct_json, "[]", "[]"); + this->AssertTake(struct_type, struct_json, "[3, 1, 3, 1, 3]", R"([ + {"a": 4, "b": "eh"}, + {"a": 1, "b": ""}, + {"a": 4, "b": "eh"}, + {"a": 1, "b": ""}, + {"a": 4, "b": "eh"} + ])"); + this->AssertTake(struct_type, struct_json, "[3, 1, 0]", R"([ + {"a": 4, "b": "eh"}, + {"a": 1, "b": ""}, + null + ])"); + this->AssertTake(struct_type, struct_json, "[0, 1, 2, 3]", struct_json); + this->AssertTake(struct_type, struct_json, "[0, 2, 2, 2, 2, 2, 2]", R"([ + null, + {"a": 2, "b": "hello"}, + {"a": 2, "b": "hello"}, + {"a": 2, "b": "hello"}, + {"a": 2, "b": "hello"}, + {"a": 2, "b": "hello"}, + {"a": 2, "b": "hello"} + ])"); +} + +class TestPermutationsWithTake : public ComputeFixture, public TestBase { + protected: + void Take(const Int16Array& values, const Int16Array& indices, + std::shared_ptr<Int16Array>* out) { + TakeOptions options; + std::shared_ptr<Array> boxed_out; + ASSERT_OK(arrow::compute::Take(&this->ctx_, values, indices, options, &boxed_out)); + ASSERT_OK(ValidateArray(*boxed_out)); + *out = checked_pointer_cast<Int16Array>(std::move(boxed_out)); + } + + std::shared_ptr<Int16Array> Take(const Int16Array& values, const Int16Array& indices) { + std::shared_ptr<Int16Array> out; + Take(values, indices, &out); + return out; + } + + std::shared_ptr<Int16Array> TakeN(uint64_t n, std::shared_ptr<Int16Array> array) { + auto power_of_2 = array; + array = Identity(array->length()); + while (n != 0) { + if (n & 1) { + array = Take(*array, *power_of_2); + } + power_of_2 = Take(*power_of_2, *power_of_2); + n >>= 1; + } + return array; + } + + template <typename Rng> + void Shuffle(const Int16Array& array, Rng& gen, std::shared_ptr<Int16Array>* shuffled) { + auto byte_length = array.length() * sizeof(int16_t); + std::shared_ptr<Buffer> data; + ASSERT_OK(array.values()->Copy(0, byte_length, &data)); + auto mutable_data = reinterpret_cast<int16_t*>(data->mutable_data()); + std::shuffle(mutable_data, mutable_data + array.length(), gen); + shuffled->reset(new Int16Array(array.length(), data)); + } + + template <typename Rng> + std::shared_ptr<Int16Array> Shuffle(const Int16Array& array, Rng& gen) { + std::shared_ptr<Int16Array> out; + Shuffle(array, gen, &out); + return out; + } + + void Identity(int64_t length, std::shared_ptr<Int16Array>* identity) { + Int16Builder identity_builder; + ASSERT_OK(identity_builder.Resize(length)); + for (int16_t i = 0; i < length; ++i) { + identity_builder.UnsafeAppend(i); + } + ASSERT_OK(identity_builder.Finish(identity)); + } + + std::shared_ptr<Int16Array> Identity(int64_t length) { + std::shared_ptr<Int16Array> out; + Identity(length, &out); + return out; + } + + std::shared_ptr<Int16Array> Inverse(const std::shared_ptr<Int16Array>& permutation) { + auto length = static_cast<int16_t>(permutation->length()); + + std::vector<bool> cycle_lengths(length + 1, false); + auto permutation_to_the_i = permutation; + for (int16_t cycle_length = 1; cycle_length <= length; ++cycle_length) { + cycle_lengths[cycle_length] = HasTrivialCycle(*permutation_to_the_i); + permutation_to_the_i = Take(*permutation, *permutation_to_the_i); + } + + uint64_t cycle_to_identity_length = 1; + for (int16_t cycle_length = length; cycle_length > 1; --cycle_length) { + if (!cycle_lengths[cycle_length]) { + continue; + } + if (cycle_to_identity_length % cycle_length == 0) { + continue; + } + if (cycle_to_identity_length > + std::numeric_limits<uint64_t>::max() / cycle_length) { + // overflow, can't compute Inverse + return nullptr; + } + cycle_to_identity_length *= cycle_length; + } + + return TakeN(cycle_to_identity_length - 1, permutation); + } + + bool HasTrivialCycle(const Int16Array& permutation) { + for (int64_t i = 0; i < permutation.length(); ++i) { + if (permutation.Value(i) == static_cast<int16_t>(i)) { + return true; + } + } + return false; + } +}; + +TEST_F(TestPermutationsWithTake, InvertPermutation) { + for (int seed : {0, kSeed, kSeed * 2 - 1}) { + std::default_random_engine gen(seed); + for (int16_t length = 0; length < 1 << 10; ++length) { + auto identity = Identity(length); + auto permutation = Shuffle(*identity, gen); + auto inverse = Inverse(permutation); + if (inverse == nullptr) { + break; + } + ASSERT_TRUE(Take(*inverse, *permutation)->Equals(identity)); + } + } } } // namespace compute diff --git a/cpp/src/arrow/compute/kernels/take.cc b/cpp/src/arrow/compute/kernels/take.cc index 17b0540..6ed9111 100644 --- a/cpp/src/arrow/compute/kernels/take.cc +++ b/cpp/src/arrow/compute/kernels/take.cc @@ -15,11 +15,12 @@ // specific language governing permissions and limitations // under the License. +#include <limits> #include <memory> #include <utility> -#include "arrow/builder.h" #include "arrow/compute/context.h" +#include "arrow/compute/kernels/take-internal.h" #include "arrow/compute/kernels/take.h" #include "arrow/util/checked_cast.h" #include "arrow/util/logging.h" @@ -30,200 +31,107 @@ namespace compute { using internal::checked_cast; -Status Take(FunctionContext* context, const Array& values, const Array& indices, - const TakeOptions& options, std::shared_ptr<Array>* out) { - Datum out_datum; - RETURN_NOT_OK( - Take(context, Datum(values.data()), Datum(indices.data()), options, &out_datum)); - *out = out_datum.make_array(); - return Status::OK(); -} - -Status Take(FunctionContext* context, const Datum& values, const Datum& indices, - const TakeOptions& options, Datum* out) { - TakeKernel kernel(values.type(), options); - RETURN_NOT_OK(kernel.Call(context, values, indices, out)); - return Status::OK(); -} - -struct TakeParameters { - FunctionContext* context; - std::shared_ptr<Array> values, indices; - TakeOptions options; - std::shared_ptr<Array>* out; -}; - -template <typename Builder, typename Scalar> -Status UnsafeAppend(Builder* builder, Scalar&& value) { - builder->UnsafeAppend(std::forward<Scalar>(value)); - return Status::OK(); -} - -Status UnsafeAppend(BinaryBuilder* builder, util::string_view value) { - RETURN_NOT_OK(builder->ReserveData(static_cast<int64_t>(value.size()))); - builder->UnsafeAppend(value); - return Status::OK(); -} - -Status UnsafeAppend(StringBuilder* builder, util::string_view value) { - RETURN_NOT_OK(builder->ReserveData(static_cast<int64_t>(value.size()))); - builder->UnsafeAppend(value); - return Status::OK(); -} - -template <bool AllValuesValid, bool AllIndicesValid, typename ValueArray, - typename IndexArray, typename OutBuilder> -Status TakeImpl(FunctionContext*, const ValueArray& values, const IndexArray& indices, - OutBuilder* builder) { - auto raw_indices = indices.raw_values(); - for (int64_t i = 0; i < indices.length(); ++i) { - if (!AllIndicesValid && indices.IsNull(i)) { - builder->UnsafeAppendNull(); - continue; - } - auto index = static_cast<int64_t>(raw_indices[i]); - if (index < 0 || index >= values.length()) { - return Status::IndexError("take index out of bounds"); - } - if (!AllValuesValid && values.IsNull(index)) { - builder->UnsafeAppendNull(); - continue; - } - RETURN_NOT_OK(UnsafeAppend(builder, values.GetView(index))); - } - return Status::OK(); -} - -template <bool AllValuesValid, typename ValueArray, typename IndexArray, - typename OutBuilder> -Status UnpackIndicesNullCount(FunctionContext* context, const ValueArray& values, - const IndexArray& indices, OutBuilder* builder) { - if (indices.null_count() == 0) { - return TakeImpl<AllValuesValid, true>(context, values, indices, builder); - } - return TakeImpl<AllValuesValid, false>(context, values, indices, builder); -} - -template <typename ValueArray, typename IndexArray, typename OutBuilder> -Status UnpackValuesNullCount(FunctionContext* context, const ValueArray& values, - const IndexArray& indices, OutBuilder* builder) { - if (values.null_count() == 0) { - return UnpackIndicesNullCount<true>(context, values, indices, builder); - } - return UnpackIndicesNullCount<false>(context, values, indices, builder); -} - +// an IndexSequence which yields the values of an Array of integers template <typename IndexType> -struct UnpackValues { - using IndexArrayRef = const typename TypeTraits<IndexType>::ArrayType&; - - template <typename ValueType> - Status Visit(const ValueType&) { - using ValueArrayRef = const typename TypeTraits<ValueType>::ArrayType&; - using OutBuilder = typename TypeTraits<ValueType>::BuilderType; - IndexArrayRef indices = checked_cast<IndexArrayRef>(*params_.indices); - ValueArrayRef values = checked_cast<ValueArrayRef>(*params_.values); - std::unique_ptr<ArrayBuilder> builder; - RETURN_NOT_OK(MakeBuilder(params_.context->memory_pool(), values.type(), &builder)); - RETURN_NOT_OK(builder->Reserve(indices.length())); - RETURN_NOT_OK(UnpackValuesNullCount(params_.context, values, indices, - checked_cast<OutBuilder*>(builder.get()))); - return builder->Finish(params_.out); - } +class ArrayIndexSequence { + public: + bool never_out_of_bounds() const { return never_out_of_bounds_; } + void set_never_out_of_bounds() { never_out_of_bounds_ = true; } - Status Visit(const NullType& t) { - auto indices_length = params_.indices->length(); - if (indices_length != 0) { - auto indices = checked_cast<IndexArrayRef>(*params_.indices).raw_values(); - auto minmax = std::minmax_element(indices, indices + indices_length); - auto min = static_cast<int64_t>(*minmax.first); - auto max = static_cast<int64_t>(*minmax.second); - if (min < 0 || max >= params_.values->length()) { - return Status::IndexError("take index out of bounds"); - } - } - params_.out->reset(new NullArray(indices_length)); - return Status::OK(); - } + constexpr ArrayIndexSequence() = default; - Status Visit(const DictionaryType& t) { - std::shared_ptr<Array> taken_indices; - const auto& values = internal::checked_cast<const DictionaryArray&>(*params_.values); - { - // To take from a dictionary, apply the current kernel to the dictionary's - // indices. (Use UnpackValues<IndexType> since IndexType is already unpacked) - auto indices = values.indices(); - TakeParameters params = params_; - params.values = indices; - params.out = &taken_indices; - UnpackValues<IndexType> unpack = {params}; - RETURN_NOT_OK(VisitTypeInline(*t.index_type(), &unpack)); + explicit ArrayIndexSequence(const Array& indices) + : indices_(&checked_cast<const NumericArray<IndexType>&>(indices)) {} + + std::pair<int64_t, bool> Next() { + if (indices_->IsNull(index_)) { + ++index_; + return std::make_pair(-1, false); } - // create output dictionary from taken indices - *params_.out = std::make_shared<DictionaryArray>(values.type(), taken_indices, - values.dictionary()); - return Status::OK(); + return std::make_pair(indices_->Value(index_++), true); } - Status Visit(const ExtensionType& t) { - // XXX can we just take from its storage? - return Status::NotImplemented("gathering values of type ", t); - } + int64_t length() const { return indices_->length(); } - Status Visit(const UnionType& t) { - return Status::NotImplemented("gathering values of type ", t); - } + int64_t null_count() const { return indices_->null_count(); } - Status Visit(const ListType& t) { - return Status::NotImplemented("gathering values of type ", t); - } + private: + const NumericArray<IndexType>* indices_ = nullptr; + int64_t index_ = 0; + bool never_out_of_bounds_ = false; +}; - Status Visit(const MapType& t) { - return Status::NotImplemented("gathering values of type ", t); - } +template <typename IndexType> +class TakeKernelImpl : public TakeKernel { + public: + explicit TakeKernelImpl(const std::shared_ptr<DataType>& value_type) + : TakeKernel(value_type) {} - Status Visit(const FixedSizeListType& t) { - return Status::NotImplemented("gathering values of type ", t); + Status Init() { + return Taker<ArrayIndexSequence<IndexType>>::Make(this->type_, &taker_); } - Status Visit(const StructType& t) { - return Status::NotImplemented("gathering values of type ", t); + Status Take(FunctionContext* ctx, const Array& values, const Array& indices_array, + std::shared_ptr<Array>* out) override { + RETURN_NOT_OK(taker_->Init(ctx->memory_pool())); + RETURN_NOT_OK(taker_->Take(values, ArrayIndexSequence<IndexType>(indices_array))); + return taker_->Finish(out); } - const TakeParameters& params_; + std::unique_ptr<Taker<ArrayIndexSequence<IndexType>>> taker_; }; struct UnpackIndices { template <typename IndexType> enable_if_integer<IndexType, Status> Visit(const IndexType&) { - UnpackValues<IndexType> unpack = {params_}; - return VisitTypeInline(*params_.values->type(), &unpack); + auto out = new TakeKernelImpl<IndexType>(value_type_); + out_->reset(out); + return out->Init(); } Status Visit(const DataType& other) { return Status::TypeError("index type not supported: ", other); } - const TakeParameters& params_; + std::shared_ptr<DataType> value_type_; + std::unique_ptr<TakeKernel>* out_; }; +Status TakeKernel::Make(const std::shared_ptr<DataType>& value_type, + const std::shared_ptr<DataType>& index_type, + std::unique_ptr<TakeKernel>* out) { + UnpackIndices visitor{value_type, out}; + return VisitTypeInline(*index_type, &visitor); +} + Status TakeKernel::Call(FunctionContext* ctx, const Datum& values, const Datum& indices, Datum* out) { if (!values.is_array() || !indices.is_array()) { return Status::Invalid("TakeKernel expects array values and indices"); } + auto values_array = values.make_array(); + auto indices_array = indices.make_array(); std::shared_ptr<Array> out_array; - TakeParameters params; - params.context = ctx; - params.values = values.make_array(); - params.indices = indices.make_array(); - params.options = options_; - params.out = &out_array; - UnpackIndices unpack = {params}; - RETURN_NOT_OK(VisitTypeInline(*indices.type(), &unpack)); + RETURN_NOT_OK(Take(ctx, *values_array, *indices_array, &out_array)); *out = Datum(out_array); return Status::OK(); } +Status Take(FunctionContext* ctx, const Array& values, const Array& indices, + const TakeOptions& options, std::shared_ptr<Array>* out) { + Datum out_datum; + RETURN_NOT_OK( + Take(ctx, Datum(values.data()), Datum(indices.data()), options, &out_datum)); + *out = out_datum.make_array(); + return Status::OK(); +} + +Status Take(FunctionContext* ctx, const Datum& values, const Datum& indices, + const TakeOptions& options, Datum* out) { + std::unique_ptr<TakeKernel> kernel; + RETURN_NOT_OK(TakeKernel::Make(values.type(), indices.type(), &kernel)); + return kernel->Call(ctx, values, indices, out); +} + } // namespace compute } // namespace arrow diff --git a/cpp/src/arrow/compute/kernels/take.h b/cpp/src/arrow/compute/kernels/take.h index 3aa5ed5..f064b72 100644 --- a/cpp/src/arrow/compute/kernels/take.h +++ b/cpp/src/arrow/compute/kernels/take.h @@ -44,40 +44,58 @@ struct ARROW_EXPORT TakeOptions {}; /// = [values[2], values[1], null, values[3]] /// = ["c", "b", null, null] /// -/// \param[in] context the FunctionContext +/// \param[in] ctx the FunctionContext /// \param[in] values array from which to take /// \param[in] indices which values to take /// \param[in] options options /// \param[out] out resulting array ARROW_EXPORT -Status Take(FunctionContext* context, const Array& values, const Array& indices, +Status Take(FunctionContext* ctx, const Array& values, const Array& indices, const TakeOptions& options, std::shared_ptr<Array>* out); /// \brief Take from an array of values at indices in another array /// -/// \param[in] context the FunctionContext +/// \param[in] ctx the FunctionContext /// \param[in] values datum from which to take /// \param[in] indices which values to take /// \param[in] options options /// \param[out] out resulting datum ARROW_EXPORT -Status Take(FunctionContext* context, const Datum& values, const Datum& indices, +Status Take(FunctionContext* ctx, const Datum& values, const Datum& indices, const TakeOptions& options, Datum* out); /// \brief BinaryKernel implementing Take operation class ARROW_EXPORT TakeKernel : public BinaryKernel { public: explicit TakeKernel(const std::shared_ptr<DataType>& type, TakeOptions options = {}) - : type_(type), options_(options) {} + : type_(type) {} + /// \brief BinaryKernel interface + /// + /// delegates to subclasses via Take() Status Call(FunctionContext* ctx, const Datum& values, const Datum& indices, Datum* out) override; + /// \brief output type of this kernel (identical to type of values taken) std::shared_ptr<DataType> out_type() const override { return type_; } - private: + /// \brief factory for TakeKernels + /// + /// \param[in] value_type constructed TakeKernel will support taking + /// values of this type + /// \param[in] index_type constructed TakeKernel will support taking + /// with indices of this type + /// \param[out] out created kernel + static Status Make(const std::shared_ptr<DataType>& value_type, + const std::shared_ptr<DataType>& index_type, + std::unique_ptr<TakeKernel>* out); + + /// \brief single-array implementation + virtual Status Take(FunctionContext* ctx, const Array& values, const Array& indices, + std::shared_ptr<Array>* out) = 0; + + protected: std::shared_ptr<DataType> type_; - TakeOptions options_; }; } // namespace compute } // namespace arrow diff --git a/cpp/src/arrow/compute/kernels/util-internal.h b/cpp/src/arrow/compute/kernels/util-internal.h index efd990f..c832583 100644 --- a/cpp/src/arrow/compute/kernels/util-internal.h +++ b/cpp/src/arrow/compute/kernels/util-internal.h @@ -131,7 +131,7 @@ class ARROW_EXPORT PrimitiveAllocatingUnaryKernel : public UnaryKernel { /// \brief Kernel used to preallocate outputs for primitive types. class ARROW_EXPORT PrimitiveAllocatingBinaryKernel : public BinaryKernel { public: - // \brief Construct with a kernel to delegate operatoions to. + // \brief Construct with a kernel to delegate operations to. // // Ownership is not taken of the delegate kernel, it must outlive // the life time of this object. diff --git a/python/pyarrow/tests/test_compute.py b/python/pyarrow/tests/test_compute.py index 655dd38..37da62c 100644 --- a/python/pyarrow/tests/test_compute.py +++ b/python/pyarrow/tests/test_compute.py @@ -51,19 +51,24 @@ def test_sum(arrow_type): ('double', np.arange(0, 0.5, 0.1)), ('string', ['a', 'b', None, 'ddd', 'ee']), ('binary', [b'a', b'b', b'c', b'ddd', b'ee']), - (pa.binary(3), [b'abc', b'bcd', b'cde', b'def', b'efg']) + (pa.binary(3), [b'abc', b'bcd', b'cde', b'def', b'efg']), + (pa.list_(pa.int8()), [[1, 2], [3, 4], [5, 6], None, [9, 16]]), + (pa.struct([('a', pa.int8()), ('b', pa.int8())]), [ + {'a': 1, 'b': 2}, None, {'a': 3, 'b': 4}, None, {'a': 5, 'b': 6}]), ]) def test_take(ty, values): arr = pa.array(values, type=ty) for indices_type in [pa.uint8(), pa.int64()]: indices = pa.array([0, 4, 2, None], type=indices_type) result = arr.take(indices) + result.validate() expected = pa.array([values[0], values[4], values[2], None], type=ty) assert result.equals(expected) # empty indices indices = pa.array([], type=indices_type) result = arr.take(indices) + result.validate() expected = pa.array([], type=ty) assert result.equals(expected) @@ -83,6 +88,7 @@ def test_take_indices_types(): 'uint32', 'int32', 'uint64', 'int64']: indices = pa.array([0, 4, 2, None], type=indices_type) result = arr.take(indices) + result.validate() expected = pa.array([0, 4, 2, None]) assert result.equals(expected) @@ -97,17 +103,7 @@ def test_take_dictionary(ordered): arr = pa.DictionaryArray.from_arrays([0, 1, 2, 0, 1, 2], ['a', 'b', 'c'], ordered=ordered) result = arr.take(pa.array([0, 1, 3])) + result.validate() assert result.to_pylist() == ['a', 'b', 'a'] assert result.dictionary.to_pylist() == ['a', 'b', 'c'] assert result.type.ordered is ordered - - -@pytest.mark.parametrize('array', [ - [[1, 2], [3, 4], [5, 6]], - [{'a': 1, 'b': 2}, None, {'a': 3, 'b': 4}], -], ids=['listarray', 'structarray']) -def test_take_notimplemented(array): - array = pa.array(array) - indices = pa.array([0, 2]) - with pytest.raises(NotImplementedError): - array.take(indices)