pitrou commented on code in PR #41373: URL: https://github.com/apache/arrow/pull/41373#discussion_r1591132130
########## cpp/src/arrow/util/gather_internal.h: ########## @@ -0,0 +1,287 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +#pragma once + +#include <cassert> +#include <cstddef> +#include <cstdint> +#include "arrow/array/data.h" +#include "arrow/util/bit_block_counter.h" +#include "arrow/util/bit_run_reader.h" +#include "arrow/util/bit_util.h" +#include "arrow/util/bitmap_ops.h" +#include "arrow/util/macros.h" + +// Implementation helpers for kernels that need to load/gather fixed-width Review Comment: If this is kernels-specific then shouldn't it go into `arrow/compute/kernels`? Or do you plan for it to be reused elsewhere? ########## cpp/src/arrow/util/gather_internal.h: ########## @@ -0,0 +1,287 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. 
You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +#pragma once + +#include <cassert> +#include <cstddef> +#include <cstdint> +#include "arrow/array/data.h" +#include "arrow/util/bit_block_counter.h" +#include "arrow/util/bit_run_reader.h" +#include "arrow/util/bit_util.h" +#include "arrow/util/bitmap_ops.h" +#include "arrow/util/macros.h" + +// Implementation helpers for kernels that need to load/gather fixed-width +// data from multiple, arbitrary indices. +// +// https://en.wikipedia.org/wiki/Gather/scatter_(vector_addressing) + +namespace arrow::internal { +inline namespace gather_internal { Review Comment: No need for an additional inner namespace IMHO. Also, I'm not sure why it is inline. ########## cpp/src/arrow/compute/kernels/vector_selection_test.cc: ########## @@ -1522,9 +1531,13 @@ TEST_F(TestTakeKernelWithFixedSizeList, TakeFixedSizeListInt32) { CheckTake(fixed_size_list(int32(), 3), list_json, "[3, 0, 0, 3]", "[[7, 8, null], null, null, [7, 8, null]]"); CheckTake(fixed_size_list(int32(), 3), list_json, "[0, 1, 2, 3]", list_json); + + // No nulls in inner list values trigger the use of FixedWidthTakeExec() in + // FSLTakeExec() Review Comment: Ok, but why remove the existing test? ########## cpp/src/arrow/compute/kernels/vector_selection_take_internal.cc: ########## @@ -324,261 +326,109 @@ namespace { using TakeState = OptionsWrapper<TakeOptions>; // ---------------------------------------------------------------------- -// Implement optimized take for primitive types from boolean to 1/2/4/8/16/32-byte -// C-type based types. 
Use common implementation for every byte width and only -// generate code for unsigned integer indices, since after boundschecking to -// check for negative numbers in the indices we can safely reinterpret_cast -// signed integers as unsigned. - -/// \brief The Take implementation for primitive (fixed-width) types does not -/// use the logical Arrow type but rather the physical C type. This way we -/// only generate one take function for each byte width. +// Implement optimized take for primitive types from boolean to +// 1/2/4/8/16/32-byte C-type based types and fixed-size binary (0 or more +// bytes). +// +// Use one specialization for each of these primitive byte-widths so the +// compiler can specialize the memcpy to dedicated CPU instructions and for +// fixed-width binary use the 1-byte specialization but pass WithFactor=true +// that makes the kernel consider the factor parameter provided at runtime. +// +// Only unsigned index types need to be instantiated since after +// boundschecking to check for negative numbers in the indices we can safely +// reinterpret_cast signed integers as unsigned. + +/// \brief The Take implementation for primitive types and fixed-width binary. /// /// Also note that this function can also handle fixed-size-list arrays if /// they fit the criteria described in fixed_width_internal.h, so use the /// function defined in that file to access values and destination pointers /// and DO NOT ASSUME `values.type()` is a primitive type. 
/// /// \pre the indices have been boundschecked -template <typename IndexCType, typename ValueWidthConstant> -struct PrimitiveTakeImpl { - static constexpr int kValueWidth = ValueWidthConstant::value; - - static void Exec(const ArraySpan& values, const ArraySpan& indices, - ArrayData* out_arr) { - DCHECK_EQ(util::FixedWidthInBytes(*values.type), kValueWidth); - const auto* values_data = util::OffsetPointerOfFixedWidthValues(values); - const uint8_t* values_is_valid = values.buffers[0].data; - auto values_offset = values.offset; - - const auto* indices_data = indices.GetValues<IndexCType>(1); - const uint8_t* indices_is_valid = indices.buffers[0].data; - auto indices_offset = indices.offset; - - DCHECK_EQ(out_arr->offset, 0); - auto* out = util::MutableFixedWidthValuesPointer(out_arr); - auto out_is_valid = out_arr->buffers[0]->mutable_data(); - - // If either the values or indices have nulls, we preemptively zero out the - // out validity bitmap so that we don't have to use ClearBit in each - // iteration for nulls. 
- if (values.null_count != 0 || indices.null_count != 0) { - bit_util::SetBitsTo(out_is_valid, 0, indices.length, false); - } - - auto WriteValue = [&](int64_t position) { - memcpy(out + position * kValueWidth, - values_data + indices_data[position] * kValueWidth, kValueWidth); - }; - - auto WriteZero = [&](int64_t position) { - memset(out + position * kValueWidth, 0, kValueWidth); - }; - - auto WriteZeroSegment = [&](int64_t position, int64_t length) { - memset(out + position * kValueWidth, 0, kValueWidth * length); - }; - - OptionalBitBlockCounter indices_bit_counter(indices_is_valid, indices_offset, - indices.length); - int64_t position = 0; - int64_t valid_count = 0; - while (position < indices.length) { - BitBlockCount block = indices_bit_counter.NextBlock(); - if (values.null_count == 0) { - // Values are never null, so things are easier - valid_count += block.popcount; - if (block.popcount == block.length) { - // Fastest path: neither values nor index nulls - bit_util::SetBitsTo(out_is_valid, position, block.length, true); - for (int64_t i = 0; i < block.length; ++i) { - WriteValue(position); - ++position; - } - } else if (block.popcount > 0) { - // Slow path: some indices but not all are null - for (int64_t i = 0; i < block.length; ++i) { - if (bit_util::GetBit(indices_is_valid, indices_offset + position)) { - // index is not null - bit_util::SetBit(out_is_valid, position); - WriteValue(position); - } else { - WriteZero(position); - } - ++position; - } - } else { - WriteZeroSegment(position, block.length); - position += block.length; - } - } else { - // Values have nulls, so we must do random access into the values bitmap - if (block.popcount == block.length) { - // Faster path: indices are not null but values may be - for (int64_t i = 0; i < block.length; ++i) { - if (bit_util::GetBit(values_is_valid, - values_offset + indices_data[position])) { - // value is not null - WriteValue(position); - bit_util::SetBit(out_is_valid, position); - ++valid_count; - } 
else { - WriteZero(position); - } - ++position; - } - } else if (block.popcount > 0) { - // Slow path: some but not all indices are null. Since we are doing - // random access in general we have to check the value nullness one by - // one. - for (int64_t i = 0; i < block.length; ++i) { - if (bit_util::GetBit(indices_is_valid, indices_offset + position) && - bit_util::GetBit(values_is_valid, - values_offset + indices_data[position])) { - // index is not null && value is not null - WriteValue(position); - bit_util::SetBit(out_is_valid, position); - ++valid_count; - } else { - WriteZero(position); - } - ++position; - } - } else { - WriteZeroSegment(position, block.length); - position += block.length; - } - } +template <typename IndexCType, typename ValueBitWidthConstant, + typename OutputIsZeroInitialized = std::false_type, + typename WithFactor = std::false_type> +struct FixedWidthTakeImpl { + static constexpr int kValueWidthInBits = ValueBitWidthConstant::value; + + // offset returned is defined as number of kValueWidthInBits blocks + static std::pair<int64_t, const uint8_t*> SourceOffsetAndValuesPointer( + const ArraySpan& values) { + if constexpr (kValueWidthInBits == 1) { + DCHECK_EQ(values.type->id(), Type::BOOL); + return {values.offset, values.GetValues<uint8_t>(1, 0)}; + } else { + static_assert(kValueWidthInBits % 8 == 0, + "kValueWidthInBits must be 1 or a multiple of 8"); + return {0, util::OffsetPointerOfFixedWidthValues(values)}; } - out_arr->null_count = out_arr->length - valid_count; } -}; -template <typename IndexCType> -struct BooleanTakeImpl { - static void Exec(const ArraySpan& values, const ArraySpan& indices, - ArrayData* out_arr) { - const uint8_t* values_data = values.buffers[1].data; - const uint8_t* values_is_valid = values.buffers[0].data; - auto values_offset = values.offset; - - const auto* indices_data = indices.GetValues<IndexCType>(1); - const uint8_t* indices_is_valid = indices.buffers[0].data; - auto indices_offset = indices.offset; - 
- auto out = out_arr->buffers[1]->mutable_data(); - auto out_is_valid = out_arr->buffers[0]->mutable_data(); - auto out_offset = out_arr->offset; - - // If either the values or indices have nulls, we preemptively zero out the - // out validity bitmap so that we don't have to use ClearBit in each - // iteration for nulls. - if (values.null_count != 0 || indices.null_count != 0) { - bit_util::SetBitsTo(out_is_valid, out_offset, indices.length, false); - } - // Avoid uninitialized data in values array - bit_util::SetBitsTo(out, out_offset, indices.length, false); - - auto PlaceDataBit = [&](int64_t loc, IndexCType index) { - bit_util::SetBitTo(out, out_offset + loc, - bit_util::GetBit(values_data, values_offset + index)); - }; - - OptionalBitBlockCounter indices_bit_counter(indices_is_valid, indices_offset, - indices.length); - int64_t position = 0; + static void Exec(KernelContext* ctx, const ArraySpan& values, const ArraySpan& indices, + ArrayData* out_arr, size_t factor) { +#ifndef NDEBUG + int64_t bit_width = util::FixedWidthInBits(*values.type); + DCHECK(WithFactor::value || (kValueWidthInBits == bit_width && factor == 1)); + DCHECK(!WithFactor::value || + (factor > 0 && kValueWidthInBits == 8 && // factors are used with bytes + static_cast<int64_t>(factor * kValueWidthInBits) == bit_width)); +#endif + const bool out_has_validity = values.MayHaveNulls() || indices.MayHaveNulls(); + + const uint8_t* src; + int64_t src_offset; + std::tie(src_offset, src) = SourceOffsetAndValuesPointer(values); + uint8_t* out = util::MutableFixedWidthValuesPointer(out_arr); int64_t valid_count = 0; - while (position < indices.length) { - BitBlockCount block = indices_bit_counter.NextBlock(); - if (values.null_count == 0) { - // Values are never null, so things are easier - valid_count += block.popcount; - if (block.popcount == block.length) { - // Fastest path: neither values nor index nulls - bit_util::SetBitsTo(out_is_valid, out_offset + position, block.length, true); - for 
(int64_t i = 0; i < block.length; ++i) { - PlaceDataBit(position, indices_data[position]); - ++position; - } - } else if (block.popcount > 0) { - // Slow path: some but not all indices are null - for (int64_t i = 0; i < block.length; ++i) { - if (bit_util::GetBit(indices_is_valid, indices_offset + position)) { - // index is not null - bit_util::SetBit(out_is_valid, out_offset + position); - PlaceDataBit(position, indices_data[position]); - } - ++position; - } - } else { - position += block.length; - } - } else { - // Values have nulls, so we must do random access into the values bitmap - if (block.popcount == block.length) { - // Faster path: indices are not null but values may be - for (int64_t i = 0; i < block.length; ++i) { - if (bit_util::GetBit(values_is_valid, - values_offset + indices_data[position])) { - // value is not null - bit_util::SetBit(out_is_valid, out_offset + position); - PlaceDataBit(position, indices_data[position]); - ++valid_count; - } - ++position; - } - } else if (block.popcount > 0) { - // Slow path: some but not all indices are null. Since we are doing - // random access in general we have to check the value nullness one by - // one. 
- for (int64_t i = 0; i < block.length; ++i) { - if (bit_util::GetBit(indices_is_valid, indices_offset + position)) { - // index is not null - if (bit_util::GetBit(values_is_valid, - values_offset + indices_data[position])) { - // value is not null - PlaceDataBit(position, indices_data[position]); - bit_util::SetBit(out_is_valid, out_offset + position); - ++valid_count; - } - } - ++position; - } - } else { - position += block.length; - } - } + arrow::internal::Gather<kValueWidthInBits, IndexCType, WithFactor::value> gather{ + /*src_length=*/values.length, + src, + src_offset, + /*idx_length=*/indices.length, + /*idx=*/indices.GetValues<IndexCType>(1), + out, + factor}; + if (out_has_validity) { + DCHECK_EQ(out_arr->offset, 0); + // out_is_valid must be zero-initiliazed, because Gather::Execute + // saves time by not having to ClearBit on every null element. + auto out_is_valid = out_arr->GetMutableValues<uint8_t>(0); + memset(out_is_valid, 0, bit_util::BytesForBits(out_arr->length)); + valid_count = gather.template Execute<OutputIsZeroInitialized::value>( + /*src_validity=*/values, /*idx_validity=*/indices, out_is_valid); + } else { + valid_count = gather.Execute(); } out_arr->null_count = out_arr->length - valid_count; } }; template <template <typename...> class TakeImpl, typename... Args> -void TakeIndexDispatch(const ArraySpan& values, const ArraySpan& indices, - ArrayData* out) { +void TakeIndexDispatch(KernelContext* ctx, const ArraySpan& values, + const ArraySpan& indices, ArrayData* out, size_t factor = 1) { Review Comment: Can we make this `int32_t factor` or `int64_t factor`? ########## cpp/src/arrow/compute/kernels/vector_selection_take_internal.cc: ########## @@ -324,261 +326,109 @@ namespace { using TakeState = OptionsWrapper<TakeOptions>; // ---------------------------------------------------------------------- -// Implement optimized take for primitive types from boolean to 1/2/4/8/16/32-byte -// C-type based types. 
Use common implementation for every byte width and only -// generate code for unsigned integer indices, since after boundschecking to -// check for negative numbers in the indices we can safely reinterpret_cast -// signed integers as unsigned. - -/// \brief The Take implementation for primitive (fixed-width) types does not -/// use the logical Arrow type but rather the physical C type. This way we -/// only generate one take function for each byte width. +// Implement optimized take for primitive types from boolean to +// 1/2/4/8/16/32-byte C-type based types and fixed-size binary (0 or more +// bytes). +// +// Use one specialization for each of these primitive byte-widths so the +// compiler can specialize the memcpy to dedicated CPU instructions and for +// fixed-width binary use the 1-byte specialization but pass WithFactor=true +// that makes the kernel consider the factor parameter provided at runtime. +// +// Only unsigned index types need to be instantiated since after +// boundschecking to check for negative numbers in the indices we can safely +// reinterpret_cast signed integers as unsigned. + +/// \brief The Take implementation for primitive types and fixed-width binary. /// /// Also note that this function can also handle fixed-size-list arrays if /// they fit the criteria described in fixed_width_internal.h, so use the /// function defined in that file to access values and destination pointers /// and DO NOT ASSUME `values.type()` is a primitive type. 
/// /// \pre the indices have been boundschecked -template <typename IndexCType, typename ValueWidthConstant> -struct PrimitiveTakeImpl { - static constexpr int kValueWidth = ValueWidthConstant::value; - - static void Exec(const ArraySpan& values, const ArraySpan& indices, - ArrayData* out_arr) { - DCHECK_EQ(util::FixedWidthInBytes(*values.type), kValueWidth); - const auto* values_data = util::OffsetPointerOfFixedWidthValues(values); - const uint8_t* values_is_valid = values.buffers[0].data; - auto values_offset = values.offset; - - const auto* indices_data = indices.GetValues<IndexCType>(1); - const uint8_t* indices_is_valid = indices.buffers[0].data; - auto indices_offset = indices.offset; - - DCHECK_EQ(out_arr->offset, 0); - auto* out = util::MutableFixedWidthValuesPointer(out_arr); - auto out_is_valid = out_arr->buffers[0]->mutable_data(); - - // If either the values or indices have nulls, we preemptively zero out the - // out validity bitmap so that we don't have to use ClearBit in each - // iteration for nulls. 
- if (values.null_count != 0 || indices.null_count != 0) { - bit_util::SetBitsTo(out_is_valid, 0, indices.length, false); - } - - auto WriteValue = [&](int64_t position) { - memcpy(out + position * kValueWidth, - values_data + indices_data[position] * kValueWidth, kValueWidth); - }; - - auto WriteZero = [&](int64_t position) { - memset(out + position * kValueWidth, 0, kValueWidth); - }; - - auto WriteZeroSegment = [&](int64_t position, int64_t length) { - memset(out + position * kValueWidth, 0, kValueWidth * length); - }; - - OptionalBitBlockCounter indices_bit_counter(indices_is_valid, indices_offset, - indices.length); - int64_t position = 0; - int64_t valid_count = 0; - while (position < indices.length) { - BitBlockCount block = indices_bit_counter.NextBlock(); - if (values.null_count == 0) { - // Values are never null, so things are easier - valid_count += block.popcount; - if (block.popcount == block.length) { - // Fastest path: neither values nor index nulls - bit_util::SetBitsTo(out_is_valid, position, block.length, true); - for (int64_t i = 0; i < block.length; ++i) { - WriteValue(position); - ++position; - } - } else if (block.popcount > 0) { - // Slow path: some indices but not all are null - for (int64_t i = 0; i < block.length; ++i) { - if (bit_util::GetBit(indices_is_valid, indices_offset + position)) { - // index is not null - bit_util::SetBit(out_is_valid, position); - WriteValue(position); - } else { - WriteZero(position); - } - ++position; - } - } else { - WriteZeroSegment(position, block.length); - position += block.length; - } - } else { - // Values have nulls, so we must do random access into the values bitmap - if (block.popcount == block.length) { - // Faster path: indices are not null but values may be - for (int64_t i = 0; i < block.length; ++i) { - if (bit_util::GetBit(values_is_valid, - values_offset + indices_data[position])) { - // value is not null - WriteValue(position); - bit_util::SetBit(out_is_valid, position); - ++valid_count; - } 
else { - WriteZero(position); - } - ++position; - } - } else if (block.popcount > 0) { - // Slow path: some but not all indices are null. Since we are doing - // random access in general we have to check the value nullness one by - // one. - for (int64_t i = 0; i < block.length; ++i) { - if (bit_util::GetBit(indices_is_valid, indices_offset + position) && - bit_util::GetBit(values_is_valid, - values_offset + indices_data[position])) { - // index is not null && value is not null - WriteValue(position); - bit_util::SetBit(out_is_valid, position); - ++valid_count; - } else { - WriteZero(position); - } - ++position; - } - } else { - WriteZeroSegment(position, block.length); - position += block.length; - } - } +template <typename IndexCType, typename ValueBitWidthConstant, + typename OutputIsZeroInitialized = std::false_type, + typename WithFactor = std::false_type> Review Comment: Why not `bool kOutputIsZeroInitialized` and `bool kWithFactor = false`? ########## cpp/src/arrow/compute/kernels/vector_selection_take_internal.cc: ########## @@ -324,261 +326,109 @@ namespace { using TakeState = OptionsWrapper<TakeOptions>; // ---------------------------------------------------------------------- -// Implement optimized take for primitive types from boolean to 1/2/4/8/16/32-byte -// C-type based types. Use common implementation for every byte width and only -// generate code for unsigned integer indices, since after boundschecking to -// check for negative numbers in the indices we can safely reinterpret_cast -// signed integers as unsigned. - -/// \brief The Take implementation for primitive (fixed-width) types does not -/// use the logical Arrow type but rather the physical C type. This way we -/// only generate one take function for each byte width. +// Implement optimized take for primitive types from boolean to +// 1/2/4/8/16/32-byte C-type based types and fixed-size binary (0 or more +// bytes). 
+// +// Use one specialization for each of these primitive byte-widths so the +// compiler can specialize the memcpy to dedicated CPU instructions and for +// fixed-width binary use the 1-byte specialization but pass WithFactor=true +// that makes the kernel consider the factor parameter provided at runtime. +// +// Only unsigned index types need to be instantiated since after +// boundschecking to check for negative numbers in the indices we can safely +// reinterpret_cast signed integers as unsigned. + +/// \brief The Take implementation for primitive types and fixed-width binary. /// /// Also note that this function can also handle fixed-size-list arrays if /// they fit the criteria described in fixed_width_internal.h, so use the /// function defined in that file to access values and destination pointers /// and DO NOT ASSUME `values.type()` is a primitive type. /// /// \pre the indices have been boundschecked -template <typename IndexCType, typename ValueWidthConstant> -struct PrimitiveTakeImpl { - static constexpr int kValueWidth = ValueWidthConstant::value; - - static void Exec(const ArraySpan& values, const ArraySpan& indices, - ArrayData* out_arr) { - DCHECK_EQ(util::FixedWidthInBytes(*values.type), kValueWidth); - const auto* values_data = util::OffsetPointerOfFixedWidthValues(values); - const uint8_t* values_is_valid = values.buffers[0].data; - auto values_offset = values.offset; - - const auto* indices_data = indices.GetValues<IndexCType>(1); - const uint8_t* indices_is_valid = indices.buffers[0].data; - auto indices_offset = indices.offset; - - DCHECK_EQ(out_arr->offset, 0); - auto* out = util::MutableFixedWidthValuesPointer(out_arr); - auto out_is_valid = out_arr->buffers[0]->mutable_data(); - - // If either the values or indices have nulls, we preemptively zero out the - // out validity bitmap so that we don't have to use ClearBit in each - // iteration for nulls. 
- if (values.null_count != 0 || indices.null_count != 0) { - bit_util::SetBitsTo(out_is_valid, 0, indices.length, false); - } - - auto WriteValue = [&](int64_t position) { - memcpy(out + position * kValueWidth, - values_data + indices_data[position] * kValueWidth, kValueWidth); - }; - - auto WriteZero = [&](int64_t position) { - memset(out + position * kValueWidth, 0, kValueWidth); - }; - - auto WriteZeroSegment = [&](int64_t position, int64_t length) { - memset(out + position * kValueWidth, 0, kValueWidth * length); - }; - - OptionalBitBlockCounter indices_bit_counter(indices_is_valid, indices_offset, - indices.length); - int64_t position = 0; - int64_t valid_count = 0; - while (position < indices.length) { - BitBlockCount block = indices_bit_counter.NextBlock(); - if (values.null_count == 0) { - // Values are never null, so things are easier - valid_count += block.popcount; - if (block.popcount == block.length) { - // Fastest path: neither values nor index nulls - bit_util::SetBitsTo(out_is_valid, position, block.length, true); - for (int64_t i = 0; i < block.length; ++i) { - WriteValue(position); - ++position; - } - } else if (block.popcount > 0) { - // Slow path: some indices but not all are null - for (int64_t i = 0; i < block.length; ++i) { - if (bit_util::GetBit(indices_is_valid, indices_offset + position)) { - // index is not null - bit_util::SetBit(out_is_valid, position); - WriteValue(position); - } else { - WriteZero(position); - } - ++position; - } - } else { - WriteZeroSegment(position, block.length); - position += block.length; - } - } else { - // Values have nulls, so we must do random access into the values bitmap - if (block.popcount == block.length) { - // Faster path: indices are not null but values may be - for (int64_t i = 0; i < block.length; ++i) { - if (bit_util::GetBit(values_is_valid, - values_offset + indices_data[position])) { - // value is not null - WriteValue(position); - bit_util::SetBit(out_is_valid, position); - ++valid_count; - } 
else { - WriteZero(position); - } - ++position; - } - } else if (block.popcount > 0) { - // Slow path: some but not all indices are null. Since we are doing - // random access in general we have to check the value nullness one by - // one. - for (int64_t i = 0; i < block.length; ++i) { - if (bit_util::GetBit(indices_is_valid, indices_offset + position) && - bit_util::GetBit(values_is_valid, - values_offset + indices_data[position])) { - // index is not null && value is not null - WriteValue(position); - bit_util::SetBit(out_is_valid, position); - ++valid_count; - } else { - WriteZero(position); - } - ++position; - } - } else { - WriteZeroSegment(position, block.length); - position += block.length; - } - } +template <typename IndexCType, typename ValueBitWidthConstant, + typename OutputIsZeroInitialized = std::false_type, + typename WithFactor = std::false_type> +struct FixedWidthTakeImpl { + static constexpr int kValueWidthInBits = ValueBitWidthConstant::value; + + // offset returned is defined as number of kValueWidthInBits blocks + static std::pair<int64_t, const uint8_t*> SourceOffsetAndValuesPointer( + const ArraySpan& values) { + if constexpr (kValueWidthInBits == 1) { + DCHECK_EQ(values.type->id(), Type::BOOL); + return {values.offset, values.GetValues<uint8_t>(1, 0)}; + } else { + static_assert(kValueWidthInBits % 8 == 0, + "kValueWidthInBits must be 1 or a multiple of 8"); + return {0, util::OffsetPointerOfFixedWidthValues(values)}; } - out_arr->null_count = out_arr->length - valid_count; } -}; -template <typename IndexCType> -struct BooleanTakeImpl { - static void Exec(const ArraySpan& values, const ArraySpan& indices, - ArrayData* out_arr) { - const uint8_t* values_data = values.buffers[1].data; - const uint8_t* values_is_valid = values.buffers[0].data; - auto values_offset = values.offset; - - const auto* indices_data = indices.GetValues<IndexCType>(1); - const uint8_t* indices_is_valid = indices.buffers[0].data; - auto indices_offset = indices.offset; - 
- auto out = out_arr->buffers[1]->mutable_data(); - auto out_is_valid = out_arr->buffers[0]->mutable_data(); - auto out_offset = out_arr->offset; - - // If either the values or indices have nulls, we preemptively zero out the - // out validity bitmap so that we don't have to use ClearBit in each - // iteration for nulls. - if (values.null_count != 0 || indices.null_count != 0) { - bit_util::SetBitsTo(out_is_valid, out_offset, indices.length, false); - } - // Avoid uninitialized data in values array - bit_util::SetBitsTo(out, out_offset, indices.length, false); - - auto PlaceDataBit = [&](int64_t loc, IndexCType index) { - bit_util::SetBitTo(out, out_offset + loc, - bit_util::GetBit(values_data, values_offset + index)); - }; - - OptionalBitBlockCounter indices_bit_counter(indices_is_valid, indices_offset, - indices.length); - int64_t position = 0; + static void Exec(KernelContext* ctx, const ArraySpan& values, const ArraySpan& indices, + ArrayData* out_arr, size_t factor) { +#ifndef NDEBUG + int64_t bit_width = util::FixedWidthInBits(*values.type); + DCHECK(WithFactor::value || (kValueWidthInBits == bit_width && factor == 1)); + DCHECK(!WithFactor::value || + (factor > 0 && kValueWidthInBits == 8 && // factors are used with bytes + static_cast<int64_t>(factor * kValueWidthInBits) == bit_width)); +#endif + const bool out_has_validity = values.MayHaveNulls() || indices.MayHaveNulls(); + + const uint8_t* src; + int64_t src_offset; + std::tie(src_offset, src) = SourceOffsetAndValuesPointer(values); Review Comment: I think this can be `auto [src_offset, src] = SourceOffsetAndValuesPointer(values);` ########## cpp/src/arrow/compute/kernels/vector_selection_take_internal.cc: ########## @@ -324,261 +326,109 @@ namespace { using TakeState = OptionsWrapper<TakeOptions>; // ---------------------------------------------------------------------- -// Implement optimized take for primitive types from boolean to 1/2/4/8/16/32-byte -// C-type based types.
Use common implementation for every byte width and only -// generate code for unsigned integer indices, since after boundschecking to -// check for negative numbers in the indices we can safely reinterpret_cast -// signed integers as unsigned. - -/// \brief The Take implementation for primitive (fixed-width) types does not -/// use the logical Arrow type but rather the physical C type. This way we -/// only generate one take function for each byte width. +// Implement optimized take for primitive types from boolean to +// 1/2/4/8/16/32-byte C-type based types and fixed-size binary (0 or more +// bytes). +// +// Use one specialization for each of these primitive byte-widths so the +// compiler can specialize the memcpy to dedicated CPU instructions and for +// fixed-width binary use the 1-byte specialization but pass WithFactor=true +// that makes the kernel consider the factor parameter provided at runtime. +// +// Only unsigned index types need to be instantiated since after +// boundschecking to check for negative numbers in the indices we can safely +// reinterpret_cast signed integers as unsigned. + +/// \brief The Take implementation for primitive types and fixed-width binary. /// /// Also note that this function can also handle fixed-size-list arrays if /// they fit the criteria described in fixed_width_internal.h, so use the /// function defined in that file to access values and destination pointers /// and DO NOT ASSUME `values.type()` is a primitive type. 
/// /// \pre the indices have been boundschecked -template <typename IndexCType, typename ValueWidthConstant> -struct PrimitiveTakeImpl { - static constexpr int kValueWidth = ValueWidthConstant::value; - - static void Exec(const ArraySpan& values, const ArraySpan& indices, - ArrayData* out_arr) { - DCHECK_EQ(util::FixedWidthInBytes(*values.type), kValueWidth); - const auto* values_data = util::OffsetPointerOfFixedWidthValues(values); - const uint8_t* values_is_valid = values.buffers[0].data; - auto values_offset = values.offset; - - const auto* indices_data = indices.GetValues<IndexCType>(1); - const uint8_t* indices_is_valid = indices.buffers[0].data; - auto indices_offset = indices.offset; - - DCHECK_EQ(out_arr->offset, 0); - auto* out = util::MutableFixedWidthValuesPointer(out_arr); - auto out_is_valid = out_arr->buffers[0]->mutable_data(); - - // If either the values or indices have nulls, we preemptively zero out the - // out validity bitmap so that we don't have to use ClearBit in each - // iteration for nulls. 
- if (values.null_count != 0 || indices.null_count != 0) { - bit_util::SetBitsTo(out_is_valid, 0, indices.length, false); - } - - auto WriteValue = [&](int64_t position) { - memcpy(out + position * kValueWidth, - values_data + indices_data[position] * kValueWidth, kValueWidth); - }; - - auto WriteZero = [&](int64_t position) { - memset(out + position * kValueWidth, 0, kValueWidth); - }; - - auto WriteZeroSegment = [&](int64_t position, int64_t length) { - memset(out + position * kValueWidth, 0, kValueWidth * length); - }; - - OptionalBitBlockCounter indices_bit_counter(indices_is_valid, indices_offset, - indices.length); - int64_t position = 0; - int64_t valid_count = 0; - while (position < indices.length) { - BitBlockCount block = indices_bit_counter.NextBlock(); - if (values.null_count == 0) { - // Values are never null, so things are easier - valid_count += block.popcount; - if (block.popcount == block.length) { - // Fastest path: neither values nor index nulls - bit_util::SetBitsTo(out_is_valid, position, block.length, true); - for (int64_t i = 0; i < block.length; ++i) { - WriteValue(position); - ++position; - } - } else if (block.popcount > 0) { - // Slow path: some indices but not all are null - for (int64_t i = 0; i < block.length; ++i) { - if (bit_util::GetBit(indices_is_valid, indices_offset + position)) { - // index is not null - bit_util::SetBit(out_is_valid, position); - WriteValue(position); - } else { - WriteZero(position); - } - ++position; - } - } else { - WriteZeroSegment(position, block.length); - position += block.length; - } - } else { - // Values have nulls, so we must do random access into the values bitmap - if (block.popcount == block.length) { - // Faster path: indices are not null but values may be - for (int64_t i = 0; i < block.length; ++i) { - if (bit_util::GetBit(values_is_valid, - values_offset + indices_data[position])) { - // value is not null - WriteValue(position); - bit_util::SetBit(out_is_valid, position); - ++valid_count; - } 
else { - WriteZero(position); - } - ++position; - } - } else if (block.popcount > 0) { - // Slow path: some but not all indices are null. Since we are doing - // random access in general we have to check the value nullness one by - // one. - for (int64_t i = 0; i < block.length; ++i) { - if (bit_util::GetBit(indices_is_valid, indices_offset + position) && - bit_util::GetBit(values_is_valid, - values_offset + indices_data[position])) { - // index is not null && value is not null - WriteValue(position); - bit_util::SetBit(out_is_valid, position); - ++valid_count; - } else { - WriteZero(position); - } - ++position; - } - } else { - WriteZeroSegment(position, block.length); - position += block.length; - } - } +template <typename IndexCType, typename ValueBitWidthConstant, + typename OutputIsZeroInitialized = std::false_type, + typename WithFactor = std::false_type> +struct FixedWidthTakeImpl { + static constexpr int kValueWidthInBits = ValueBitWidthConstant::value; + + // offset returned is defined as number of kValueWidthInBits blocks + static std::pair<int64_t, const uint8_t*> SourceOffsetAndValuesPointer( + const ArraySpan& values) { + if constexpr (kValueWidthInBits == 1) { + DCHECK_EQ(values.type->id(), Type::BOOL); + return {values.offset, values.GetValues<uint8_t>(1, 0)}; + } else { + static_assert(kValueWidthInBits % 8 == 0, + "kValueWidthInBits must be 1 or a multiple of 8"); + return {0, util::OffsetPointerOfFixedWidthValues(values)}; } - out_arr->null_count = out_arr->length - valid_count; } -}; -template <typename IndexCType> -struct BooleanTakeImpl { - static void Exec(const ArraySpan& values, const ArraySpan& indices, - ArrayData* out_arr) { - const uint8_t* values_data = values.buffers[1].data; - const uint8_t* values_is_valid = values.buffers[0].data; - auto values_offset = values.offset; - - const auto* indices_data = indices.GetValues<IndexCType>(1); - const uint8_t* indices_is_valid = indices.buffers[0].data; - auto indices_offset = indices.offset; - 
- auto out = out_arr->buffers[1]->mutable_data(); - auto out_is_valid = out_arr->buffers[0]->mutable_data(); - auto out_offset = out_arr->offset; - - // If either the values or indices have nulls, we preemptively zero out the - // out validity bitmap so that we don't have to use ClearBit in each - // iteration for nulls. - if (values.null_count != 0 || indices.null_count != 0) { - bit_util::SetBitsTo(out_is_valid, out_offset, indices.length, false); - } - // Avoid uninitialized data in values array - bit_util::SetBitsTo(out, out_offset, indices.length, false); - - auto PlaceDataBit = [&](int64_t loc, IndexCType index) { - bit_util::SetBitTo(out, out_offset + loc, - bit_util::GetBit(values_data, values_offset + index)); - }; - - OptionalBitBlockCounter indices_bit_counter(indices_is_valid, indices_offset, - indices.length); - int64_t position = 0; + static void Exec(KernelContext* ctx, const ArraySpan& values, const ArraySpan& indices, + ArrayData* out_arr, size_t factor) { +#ifndef NDEBUG + int64_t bit_width = util::FixedWidthInBits(*values.type); + DCHECK(WithFactor::value || (kValueWidthInBits == bit_width && factor == 1)); + DCHECK(!WithFactor::value || + (factor > 0 && kValueWidthInBits == 8 && // factors are used with bytes + static_cast<int64_t>(factor * kValueWidthInBits) == bit_width)); +#endif + const bool out_has_validity = values.MayHaveNulls() || indices.MayHaveNulls(); + + const uint8_t* src; + int64_t src_offset; + std::tie(src_offset, src) = SourceOffsetAndValuesPointer(values); + uint8_t* out = util::MutableFixedWidthValuesPointer(out_arr); int64_t valid_count = 0; - while (position < indices.length) { - BitBlockCount block = indices_bit_counter.NextBlock(); - if (values.null_count == 0) { - // Values are never null, so things are easier - valid_count += block.popcount; - if (block.popcount == block.length) { - // Fastest path: neither values nor index nulls - bit_util::SetBitsTo(out_is_valid, out_offset + position, block.length, true); - for 
(int64_t i = 0; i < block.length; ++i) { - PlaceDataBit(position, indices_data[position]); - ++position; - } - } else if (block.popcount > 0) { - // Slow path: some but not all indices are null - for (int64_t i = 0; i < block.length; ++i) { - if (bit_util::GetBit(indices_is_valid, indices_offset + position)) { - // index is not null - bit_util::SetBit(out_is_valid, out_offset + position); - PlaceDataBit(position, indices_data[position]); - } - ++position; - } - } else { - position += block.length; - } - } else { - // Values have nulls, so we must do random access into the values bitmap - if (block.popcount == block.length) { - // Faster path: indices are not null but values may be - for (int64_t i = 0; i < block.length; ++i) { - if (bit_util::GetBit(values_is_valid, - values_offset + indices_data[position])) { - // value is not null - bit_util::SetBit(out_is_valid, out_offset + position); - PlaceDataBit(position, indices_data[position]); - ++valid_count; - } - ++position; - } - } else if (block.popcount > 0) { - // Slow path: some but not all indices are null. Since we are doing - // random access in general we have to check the value nullness one by - // one. 
- for (int64_t i = 0; i < block.length; ++i) { - if (bit_util::GetBit(indices_is_valid, indices_offset + position)) { - // index is not null - if (bit_util::GetBit(values_is_valid, - values_offset + indices_data[position])) { - // value is not null - PlaceDataBit(position, indices_data[position]); - bit_util::SetBit(out_is_valid, out_offset + position); - ++valid_count; - } - } - ++position; - } - } else { - position += block.length; - } - } + arrow::internal::Gather<kValueWidthInBits, IndexCType, WithFactor::value> gather{ + /*src_length=*/values.length, + src, + src_offset, + /*idx_length=*/indices.length, + /*idx=*/indices.GetValues<IndexCType>(1), + out, + factor}; + if (out_has_validity) { + DCHECK_EQ(out_arr->offset, 0); + // out_is_valid must be zero-initiliazed, because Gather::Execute + // saves time by not having to ClearBit on every null element. + auto out_is_valid = out_arr->GetMutableValues<uint8_t>(0); + memset(out_is_valid, 0, bit_util::BytesForBits(out_arr->length)); + valid_count = gather.template Execute<OutputIsZeroInitialized::value>( + /*src_validity=*/values, /*idx_validity=*/indices, out_is_valid); + } else { + valid_count = gather.Execute(); } out_arr->null_count = out_arr->length - valid_count; } }; template <template <typename...> class TakeImpl, typename... Args> -void TakeIndexDispatch(const ArraySpan& values, const ArraySpan& indices, - ArrayData* out) { +void TakeIndexDispatch(KernelContext* ctx, const ArraySpan& values, + const ArraySpan& indices, ArrayData* out, size_t factor = 1) { // With the simplifying assumption that boundschecking has taken place // already at a higher level, we can now assume that the index values are all // non-negative. Thus, we can interpret signed integers as unsigned and avoid // having to generate double the amount of binary code to handle each integer // width. 
switch (indices.type->byte_width()) { case 1: - return TakeImpl<uint8_t, Args...>::Exec(values, indices, out); + return TakeImpl<uint8_t, Args...>::Exec(ctx, values, indices, out, factor); case 2: - return TakeImpl<uint16_t, Args...>::Exec(values, indices, out); + return TakeImpl<uint16_t, Args...>::Exec(ctx, values, indices, out, factor); case 4: - return TakeImpl<uint32_t, Args...>::Exec(values, indices, out); + return TakeImpl<uint32_t, Args...>::Exec(ctx, values, indices, out, factor); case 8: - return TakeImpl<uint64_t, Args...>::Exec(values, indices, out); - default: - DCHECK(false) << "Invalid indices byte width"; - break; + return TakeImpl<uint64_t, Args...>::Exec(ctx, values, indices, out, factor); } + DCHECK(false) << "Invalid indices byte width"; Review Comment: Won't the absence of a "default" clause above trigger some compiler warnings? ########## cpp/src/arrow/util/gather_internal.h: ########## @@ -0,0 +1,287 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. 
+ +#pragma once + +#include <cassert> +#include <cstddef> +#include <cstdint> +#include "arrow/array/data.h" +#include "arrow/util/bit_block_counter.h" +#include "arrow/util/bit_run_reader.h" +#include "arrow/util/bit_util.h" +#include "arrow/util/bitmap_ops.h" +#include "arrow/util/macros.h" + +// Implementation helpers for kernels that need to load/gather fixed-width +// data from multiple, arbitrary indices. +// +// https://en.wikipedia.org/wiki/Gather/scatter_(vector_addressing) + +namespace arrow::internal { +inline namespace gather_internal { +// CRTP [1] base class for Gather that provides a gathering loop in terms of +// Write*() methods that must be implemented by the derived class. +// +// [1] https://en.wikipedia.org/wiki/Curiously_recurring_template_pattern +template <class GatherImpl> +class GatherBaseCRTP { + public: + // Output offset is not supported by Gather and idx is supposed to have offset + // pre-applied. idx_validity parameters on functions can use the offset they + // carry to read the validity bitmap as bitmaps can't have pre-applied offsets + // (they might not align to byte boundaries). + GatherBaseCRTP() = default; + GatherBaseCRTP(const GatherBaseCRTP&) = delete; + GatherBaseCRTP(GatherBaseCRTP&&) = delete; + GatherBaseCRTP& operator=(const GatherBaseCRTP&) = delete; + GatherBaseCRTP& operator=(GatherBaseCRTP&&) = delete; + + protected: + ARROW_FORCE_INLINE int64_t ExecuteNoNulls(int64_t idx_length) { + auto* self = static_cast<GatherImpl*>(this); + for (int64_t position = 0; position < idx_length; position++) { + self->WriteValue(position); + } + return idx_length; + } + + // See derived Gather classes below for the meaning of the parameters, pre and + // post-conditions. 
+ template <bool kOutputIsZeroInitialized, typename IndexCType> + ARROW_FORCE_INLINE int64_t ExecuteWithNulls(const ArraySpan& src_validity, + int64_t idx_length, const IndexCType* idx, + const ArraySpan& idx_validity, + uint8_t* out_is_valid) { + auto* self = static_cast<GatherImpl*>(this); + OptionalBitBlockCounter indices_bit_counter(idx_validity.buffers[0].data, + idx_validity.offset, idx_length); + int64_t position = 0; + int64_t valid_count = 0; + while (position < idx_length) { + BitBlockCount block = indices_bit_counter.NextBlock(); + if (!src_validity.MayHaveNulls()) { + // Source values are never null, so things are easier + valid_count += block.popcount; + if (block.popcount == block.length) { + // Fastest path: neither source values nor index nulls + bit_util::SetBitsTo(out_is_valid, position, block.length, true); + for (int64_t i = 0; i < block.length; ++i) { + self->WriteValue(position); + ++position; + } + } else if (block.popcount > 0) { + // Slow path: some indices but not all are null + for (int64_t i = 0; i < block.length; ++i) { + ARROW_COMPILER_ASSUME(idx_validity.buffers[0].data != nullptr); + if (idx_validity.IsValid(position)) { + // index is not null + bit_util::SetBit(out_is_valid, position); + self->WriteValue(position); + } else if constexpr (!kOutputIsZeroInitialized) { + self->WriteZero(position); + } + ++position; + } + } else { + self->WriteZeroSegment(position, block.length); + position += block.length; + } + } else { + // Source values may be null, so we must do random access into src_validity + if (block.popcount == block.length) { + // Faster path: indices are not null but source values may be + for (int64_t i = 0; i < block.length; ++i) { + ARROW_COMPILER_ASSUME(src_validity.buffers[0].data != nullptr); + if (src_validity.IsValid(idx[position])) { + // value is not null + self->WriteValue(position); + bit_util::SetBit(out_is_valid, position); + ++valid_count; + } else if constexpr (!kOutputIsZeroInitialized) { + 
self->WriteZero(position); + } + ++position; + } + } else if (block.popcount > 0) { + // Slow path: some but not all indices are null. Since we are doing + // random access in general we have to check the value nullness one by + // one. + for (int64_t i = 0; i < block.length; ++i) { + ARROW_COMPILER_ASSUME(src_validity.buffers[0].data != nullptr); + ARROW_COMPILER_ASSUME(idx_validity.buffers[0].data != nullptr); + if (idx_validity.IsValid(position) && src_validity.IsValid(idx[position])) { + // index is not null && value is not null + self->WriteValue(position); + bit_util::SetBit(out_is_valid, position); + ++valid_count; + } else if constexpr (!kOutputIsZeroInitialized) { + self->WriteZero(position); + } + ++position; + } + } else { + if constexpr (!kOutputIsZeroInitialized) { + self->WriteZeroSegment(position, block.length); + } + position += block.length; + } + } + } + return valid_count; + } +}; + +template <int kValueWidthInBits, typename IndexCType, bool WithFactor, + std::enable_if_t<kValueWidthInBits % 8 == 0 || kValueWidthInBits == 1, bool> = + true> +class Gather : public GatherBaseCRTP<Gather<kValueWidthInBits, IndexCType, WithFactor>> { + public: + static constexpr int kValueWidth = kValueWidthInBits / 8; + + private: + const int64_t src_length_; // number of elements of kValueWidth bytes in src_ + const uint8_t* src_; + const int64_t idx_length_; // number IndexCType elements in idx_ + const IndexCType* idx_; + uint8_t* out_; + size_t factor_; + + public: + void WriteValue(int64_t position) { + if constexpr (WithFactor) { + const size_t scaled_factor = kValueWidth * factor_; + memcpy(out_ + position * scaled_factor, src_ + idx_[position] * scaled_factor, + scaled_factor); + } else { + memcpy(out_ + position * kValueWidth, src_ + idx_[position] * kValueWidth, + kValueWidth); + } + } + + void WriteZero(int64_t position) { + if constexpr (WithFactor) { + const size_t scaled_factor = kValueWidth * factor_; + memset(out_ + position * scaled_factor, 0, 
scaled_factor); + } else { + memset(out_ + position * kValueWidth, 0, kValueWidth); + } + } + + void WriteZeroSegment(int64_t position, int64_t length) { + if constexpr (WithFactor) { + const size_t scaled_factor = kValueWidth * factor_; + memset(out_ + position * scaled_factor, 0, length * scaled_factor); + } else { + memset(out_ + position * kValueWidth, 0, length * kValueWidth); + } + } + + public: + Gather(int64_t src_length, const uint8_t* src, int64_t zero_src_offset, + int64_t idx_length, const IndexCType* idx, uint8_t* out, size_t factor) + : src_length_(src_length), + src_(src), + idx_length_(idx_length), + idx_(idx), + out_(out), + factor_(factor) { + assert(zero_src_offset == 0); + assert(src && idx && out); + assert((WithFactor || factor == 1) && + "When WithFactor is false, the factor is assumed to be 1 at compile time"); + } + + ARROW_FORCE_INLINE int64_t Execute() { return this->ExecuteNoNulls(idx_length_); } + + /// \pre If kOutputIsZeroInitialized, then this->out_ has to be zero initialized. + /// \pre Bits in out_is_valid have to always be zero initialized. + /// \post The bits for the valid elements (and only those) are set in out_is_valid. + /// \post If !kOutputIsZeroInitialized, then positions in this->out_ containing null + /// elements have 0s written to them. This might be less efficient than + /// zero-initializing first and calling this->Execute() afterwards. + /// \return The number of valid elements in out.
+ template <bool kOutputIsZeroInitialized = false> + ARROW_FORCE_INLINE int64_t Execute(const ArraySpan& src_validity, + const ArraySpan& idx_validity, + uint8_t* out_is_valid) { + assert(src_length_ == src_validity.length); + assert(idx_length_ == idx_validity.length); + assert(out_is_valid); + return this->template ExecuteWithNulls<kOutputIsZeroInitialized>( + src_validity, idx_length_, idx_, idx_validity, out_is_valid); + } +}; + +template <typename IndexCType> +class Gather<1, IndexCType, /*WithFactor=*/false> Review Comment: ```suggestion class Gather</*kValueWidthInBits=*/ 1, IndexCType, /*WithFactor=*/false> ``` ########## cpp/src/arrow/util/gather_internal.h: ########## @@ -0,0 +1,287 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +#pragma once + +#include <cassert> +#include <cstddef> +#include <cstdint> +#include "arrow/array/data.h" +#include "arrow/util/bit_block_counter.h" +#include "arrow/util/bit_run_reader.h" +#include "arrow/util/bit_util.h" +#include "arrow/util/bitmap_ops.h" +#include "arrow/util/macros.h" + +// Implementation helpers for kernels that need to load/gather fixed-width +// data from multiple, arbitrary indices. 
+// +// https://en.wikipedia.org/wiki/Gather/scatter_(vector_addressing) + +namespace arrow::internal { +inline namespace gather_internal { +// CRTP [1] base class for Gather that provides a gathering loop in terms of +// Write*() methods that must be implemented by the derived class. +// +// [1] https://en.wikipedia.org/wiki/Curiously_recurring_template_pattern +template <class GatherImpl> +class GatherBaseCRTP { + public: + // Output offset is not supported by Gather and idx is supposed to have offset + // pre-applied. idx_validity parameters on functions can use the offset they + // carry to read the validity bitmap as bitmaps can't have pre-applied offsets + // (they might not align to byte boundaries). + GatherBaseCRTP() = default; + GatherBaseCRTP(const GatherBaseCRTP&) = delete; + GatherBaseCRTP(GatherBaseCRTP&&) = delete; + GatherBaseCRTP& operator=(const GatherBaseCRTP&) = delete; + GatherBaseCRTP& operator=(GatherBaseCRTP&&) = delete; + + protected: + ARROW_FORCE_INLINE int64_t ExecuteNoNulls(int64_t idx_length) { + auto* self = static_cast<GatherImpl*>(this); + for (int64_t position = 0; position < idx_length; position++) { + self->WriteValue(position); + } + return idx_length; + } + + // See derived Gather classes below for the meaning of the parameters, pre and + // post-conditions. 
+ template <bool kOutputIsZeroInitialized, typename IndexCType> + ARROW_FORCE_INLINE int64_t ExecuteWithNulls(const ArraySpan& src_validity, + int64_t idx_length, const IndexCType* idx, + const ArraySpan& idx_validity, + uint8_t* out_is_valid) { + auto* self = static_cast<GatherImpl*>(this); + OptionalBitBlockCounter indices_bit_counter(idx_validity.buffers[0].data, + idx_validity.offset, idx_length); + int64_t position = 0; + int64_t valid_count = 0; + while (position < idx_length) { + BitBlockCount block = indices_bit_counter.NextBlock(); + if (!src_validity.MayHaveNulls()) { + // Source values are never null, so things are easier + valid_count += block.popcount; + if (block.popcount == block.length) { + // Fastest path: neither source values nor index nulls + bit_util::SetBitsTo(out_is_valid, position, block.length, true); + for (int64_t i = 0; i < block.length; ++i) { + self->WriteValue(position); + ++position; + } + } else if (block.popcount > 0) { + // Slow path: some indices but not all are null + for (int64_t i = 0; i < block.length; ++i) { + ARROW_COMPILER_ASSUME(idx_validity.buffers[0].data != nullptr); + if (idx_validity.IsValid(position)) { + // index is not null + bit_util::SetBit(out_is_valid, position); + self->WriteValue(position); + } else if constexpr (!kOutputIsZeroInitialized) { + self->WriteZero(position); + } + ++position; + } + } else { + self->WriteZeroSegment(position, block.length); + position += block.length; + } + } else { + // Source values may be null, so we must do random access into src_validity + if (block.popcount == block.length) { + // Faster path: indices are not null but source values may be + for (int64_t i = 0; i < block.length; ++i) { + ARROW_COMPILER_ASSUME(src_validity.buffers[0].data != nullptr); + if (src_validity.IsValid(idx[position])) { + // value is not null + self->WriteValue(position); + bit_util::SetBit(out_is_valid, position); + ++valid_count; + } else if constexpr (!kOutputIsZeroInitialized) { + 
self->WriteZero(position); + } + ++position; + } + } else if (block.popcount > 0) { + // Slow path: some but not all indices are null. Since we are doing + // random access in general we have to check the value nullness one by + // one. + for (int64_t i = 0; i < block.length; ++i) { + ARROW_COMPILER_ASSUME(src_validity.buffers[0].data != nullptr); + ARROW_COMPILER_ASSUME(idx_validity.buffers[0].data != nullptr); + if (idx_validity.IsValid(position) && src_validity.IsValid(idx[position])) { + // index is not null && value is not null + self->WriteValue(position); + bit_util::SetBit(out_is_valid, position); + ++valid_count; + } else if constexpr (!kOutputIsZeroInitialized) { + self->WriteZero(position); + } + ++position; + } + } else { + if constexpr (!kOutputIsZeroInitialized) { + self->WriteZeroSegment(position, block.length); + } + position += block.length; + } + } + } + return valid_count; + } +}; + +template <int kValueWidthInBits, typename IndexCType, bool WithFactor, Review Comment: ```suggestion template <int kValueWidthInBits, typename IndexCType, bool kWithFactor, ``` ########## cpp/src/arrow/util/macros.h: ########## @@ -102,7 +102,7 @@ #elif defined(_MSC_VER) // MSVC #define ARROW_NORETURN __declspec(noreturn) #define ARROW_NOINLINE __declspec(noinline) -#define ARROW_FORCE_INLINE __declspec(forceinline) +#define ARROW_FORCE_INLINE __forceinline Review Comment: Any reason for this change? ########## cpp/src/arrow/util/gather_internal.h: ########## @@ -0,0 +1,287 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. 
You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +#pragma once + +#include <cassert> +#include <cstddef> +#include <cstdint> +#include "arrow/array/data.h" +#include "arrow/util/bit_block_counter.h" +#include "arrow/util/bit_run_reader.h" +#include "arrow/util/bit_util.h" +#include "arrow/util/bitmap_ops.h" +#include "arrow/util/macros.h" + +// Implementation helpers for kernels that need to load/gather fixed-width +// data from multiple, arbitrary indices. +// +// https://en.wikipedia.org/wiki/Gather/scatter_(vector_addressing) + +namespace arrow::internal { +inline namespace gather_internal { +// CRTP [1] base class for Gather that provides a gathering loop in terms of +// Write*() methods that must be implemented by the derived class. +// +// [1] https://en.wikipedia.org/wiki/Curiously_recurring_template_pattern +template <class GatherImpl> +class GatherBaseCRTP { + public: + // Output offset is not supported by Gather and idx is supposed to have offset + // pre-applied. idx_validity parameters on functions can use the offset they + // carry to read the validity bitmap as bitmaps can't have pre-applied offsets + // (they might not align to byte boundaries). 
+ GatherBaseCRTP() = default; + GatherBaseCRTP(const GatherBaseCRTP&) = delete; + GatherBaseCRTP(GatherBaseCRTP&&) = delete; + GatherBaseCRTP& operator=(const GatherBaseCRTP&) = delete; + GatherBaseCRTP& operator=(GatherBaseCRTP&&) = delete; + + protected: + ARROW_FORCE_INLINE int64_t ExecuteNoNulls(int64_t idx_length) { + auto* self = static_cast<GatherImpl*>(this); + for (int64_t position = 0; position < idx_length; position++) { + self->WriteValue(position); + } + return idx_length; + } + + // See derived Gather classes below for the meaning of the parameters, pre and + // post-conditions. + template <bool kOutputIsZeroInitialized, typename IndexCType> + ARROW_FORCE_INLINE int64_t ExecuteWithNulls(const ArraySpan& src_validity, + int64_t idx_length, const IndexCType* idx, + const ArraySpan& idx_validity, + uint8_t* out_is_valid) { + auto* self = static_cast<GatherImpl*>(this); + OptionalBitBlockCounter indices_bit_counter(idx_validity.buffers[0].data, + idx_validity.offset, idx_length); + int64_t position = 0; + int64_t valid_count = 0; + while (position < idx_length) { + BitBlockCount block = indices_bit_counter.NextBlock(); + if (!src_validity.MayHaveNulls()) { + // Source values are never null, so things are easier + valid_count += block.popcount; + if (block.popcount == block.length) { + // Fastest path: neither source values nor index nulls + bit_util::SetBitsTo(out_is_valid, position, block.length, true); + for (int64_t i = 0; i < block.length; ++i) { + self->WriteValue(position); + ++position; + } + } else if (block.popcount > 0) { + // Slow path: some indices but not all are null + for (int64_t i = 0; i < block.length; ++i) { + ARROW_COMPILER_ASSUME(idx_validity.buffers[0].data != nullptr); + if (idx_validity.IsValid(position)) { + // index is not null + bit_util::SetBit(out_is_valid, position); + self->WriteValue(position); + } else if constexpr (!kOutputIsZeroInitialized) { + self->WriteZero(position); + } + ++position; + } + } else { + 
self->WriteZeroSegment(position, block.length); + position += block.length; + } + } else { + // Source values may be null, so we must do random access into src_validity + if (block.popcount == block.length) { + // Faster path: indices are not null but source values may be + for (int64_t i = 0; i < block.length; ++i) { + ARROW_COMPILER_ASSUME(src_validity.buffers[0].data != nullptr); + if (src_validity.IsValid(idx[position])) { + // value is not null + self->WriteValue(position); + bit_util::SetBit(out_is_valid, position); + ++valid_count; + } else if constexpr (!kOutputIsZeroInitialized) { + self->WriteZero(position); + } + ++position; + } + } else if (block.popcount > 0) { + // Slow path: some but not all indices are null. Since we are doing + // random access in general we have to check the value nullness one by + // one. + for (int64_t i = 0; i < block.length; ++i) { + ARROW_COMPILER_ASSUME(src_validity.buffers[0].data != nullptr); + ARROW_COMPILER_ASSUME(idx_validity.buffers[0].data != nullptr); + if (idx_validity.IsValid(position) && src_validity.IsValid(idx[position])) { + // index is not null && value is not null + self->WriteValue(position); + bit_util::SetBit(out_is_valid, position); + ++valid_count; + } else if constexpr (!kOutputIsZeroInitialized) { + self->WriteZero(position); + } + ++position; + } + } else { + if constexpr (!kOutputIsZeroInitialized) { + self->WriteZeroSegment(position, block.length); + } + position += block.length; + } + } + } + return valid_count; + } +}; + +template <int kValueWidthInBits, typename IndexCType, bool WithFactor, + std::enable_if_t<kValueWidthInBits % 8 == 0 || kValueWidthInBits == 1, bool> = Review Comment: Should this condition be a `static_assert` instead? ########## cpp/src/arrow/util/gather_internal.h: ########## @@ -0,0 +1,287 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. 
See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +#pragma once + +#include <cassert> +#include <cstddef> +#include <cstdint> +#include "arrow/array/data.h" +#include "arrow/util/bit_block_counter.h" +#include "arrow/util/bit_run_reader.h" +#include "arrow/util/bit_util.h" +#include "arrow/util/bitmap_ops.h" +#include "arrow/util/macros.h" + +// Implementation helpers for kernels that need to load/gather fixed-width +// data from multiple, arbitrary indices. +// +// https://en.wikipedia.org/wiki/Gather/scatter_(vector_addressing) + +namespace arrow::internal { +inline namespace gather_internal { +// CRTP [1] base class for Gather that provides a gathering loop in terms of +// Write*() methods that must be implemented by the derived class. +// +// [1] https://en.wikipedia.org/wiki/Curiously_recurring_template_pattern +template <class GatherImpl> +class GatherBaseCRTP { + public: + // Output offset is not supported by Gather and idx is supposed to have offset + // pre-applied. idx_validity parameters on functions can use the offset they + // carry to read the validity bitmap as bitmaps can't have pre-applied offsets + // (they might not align to byte boundaries). 
+ GatherBaseCRTP() = default; + GatherBaseCRTP(const GatherBaseCRTP&) = delete; + GatherBaseCRTP(GatherBaseCRTP&&) = delete; + GatherBaseCRTP& operator=(const GatherBaseCRTP&) = delete; + GatherBaseCRTP& operator=(GatherBaseCRTP&&) = delete; + + protected: + ARROW_FORCE_INLINE int64_t ExecuteNoNulls(int64_t idx_length) { + auto* self = static_cast<GatherImpl*>(this); + for (int64_t position = 0; position < idx_length; position++) { + self->WriteValue(position); + } + return idx_length; + } + + // See derived Gather classes below for the meaning of the parameters, pre and + // post-conditions. + template <bool kOutputIsZeroInitialized, typename IndexCType> + ARROW_FORCE_INLINE int64_t ExecuteWithNulls(const ArraySpan& src_validity, + int64_t idx_length, const IndexCType* idx, + const ArraySpan& idx_validity, + uint8_t* out_is_valid) { + auto* self = static_cast<GatherImpl*>(this); + OptionalBitBlockCounter indices_bit_counter(idx_validity.buffers[0].data, + idx_validity.offset, idx_length); + int64_t position = 0; + int64_t valid_count = 0; + while (position < idx_length) { + BitBlockCount block = indices_bit_counter.NextBlock(); + if (!src_validity.MayHaveNulls()) { + // Source values are never null, so things are easier + valid_count += block.popcount; + if (block.popcount == block.length) { + // Fastest path: neither source values nor index nulls + bit_util::SetBitsTo(out_is_valid, position, block.length, true); + for (int64_t i = 0; i < block.length; ++i) { + self->WriteValue(position); + ++position; + } + } else if (block.popcount > 0) { + // Slow path: some indices but not all are null + for (int64_t i = 0; i < block.length; ++i) { + ARROW_COMPILER_ASSUME(idx_validity.buffers[0].data != nullptr); + if (idx_validity.IsValid(position)) { + // index is not null + bit_util::SetBit(out_is_valid, position); + self->WriteValue(position); + } else if constexpr (!kOutputIsZeroInitialized) { + self->WriteZero(position); + } + ++position; + } + } else { + 
self->WriteZeroSegment(position, block.length); + position += block.length; + } + } else { + // Source values may be null, so we must do random access into src_validity + if (block.popcount == block.length) { + // Faster path: indices are not null but source values may be + for (int64_t i = 0; i < block.length; ++i) { + ARROW_COMPILER_ASSUME(src_validity.buffers[0].data != nullptr); + if (src_validity.IsValid(idx[position])) { + // value is not null + self->WriteValue(position); + bit_util::SetBit(out_is_valid, position); + ++valid_count; + } else if constexpr (!kOutputIsZeroInitialized) { + self->WriteZero(position); + } + ++position; + } + } else if (block.popcount > 0) { + // Slow path: some but not all indices are null. Since we are doing + // random access in general we have to check the value nullness one by + // one. + for (int64_t i = 0; i < block.length; ++i) { + ARROW_COMPILER_ASSUME(src_validity.buffers[0].data != nullptr); + ARROW_COMPILER_ASSUME(idx_validity.buffers[0].data != nullptr); + if (idx_validity.IsValid(position) && src_validity.IsValid(idx[position])) { + // index is not null && value is not null + self->WriteValue(position); + bit_util::SetBit(out_is_valid, position); + ++valid_count; + } else if constexpr (!kOutputIsZeroInitialized) { + self->WriteZero(position); + } + ++position; + } + } else { + if constexpr (!kOutputIsZeroInitialized) { + self->WriteZeroSegment(position, block.length); + } + position += block.length; + } + } + } + return valid_count; + } +}; + +template <int kValueWidthInBits, typename IndexCType, bool WithFactor, + std::enable_if_t<kValueWidthInBits % 8 == 0 || kValueWidthInBits == 1, bool> = + true> +class Gather : public GatherBaseCRTP<Gather<kValueWidthInBits, IndexCType, WithFactor>> { + public: + static constexpr int kValueWidth = kValueWidthInBits / 8; + + private: + const int64_t src_length_; // number of elements of kValueWidth bytes in src_ + const uint8_t* src_; + const int64_t idx_length_; // number 
IndexCType elements in idx_ + const IndexCType* idx_; + uint8_t* out_; + size_t factor_; Review Comment: It would be nice to keep our convention of using `int32_t` or `int64_t` for `factor_` here, rather than `size_t`. -- This is an automated message from the Apache Git Service. To respond to the message, please log on to GitHub and use the URL above to go to the specific comment. To unsubscribe, e-mail: [email protected] For queries about this service, please contact Infrastructure at: [email protected]
