pitrou commented on code in PR #47294: URL: https://github.com/apache/arrow/pull/47294#discussion_r2319532082
########## cpp/src/arrow/util/rle_encoding_internal.h: ########## @@ -299,385 +552,988 @@ class RleEncoder { uint8_t* literal_indicator_byte_; }; +/************************* + * RleBitPackedDecoder * + *************************/ + +template <typename T> +RleBitPackedDecoder<T>::RleBitPackedDecoder(raw_data_const_pointer data, + raw_data_size_type data_size, + bit_size_type value_bit_width) noexcept { + Reset(data, data_size, value_bit_width); +} + +template <typename T> +void RleBitPackedDecoder<T>::Reset(raw_data_const_pointer data, + raw_data_size_type data_size, + bit_size_type value_bit_width) noexcept { + ARROW_DCHECK_GE(value_bit_width, 0); + ARROW_DCHECK_LE(value_bit_width, 64); + parser_.Reset(data, data_size, value_bit_width); + decoder_ = {}; +} + +template <typename T> +auto RleBitPackedDecoder<T>::RunRemaining() const -> values_count_type { + return std::visit([](auto const& dec) { return dec.Remaining(); }, decoder_); +} + +template <typename T> +bool RleBitPackedDecoder<T>::Exhausted() const { + return (RunRemaining() == 0) && parser_.Exhausted(); +} + template <typename T> -inline bool RleDecoder::Get(T* val) { +bool RleBitPackedDecoder<T>::ParseAndResetDecoder() { + auto dyn_run = parser_.Next(); + if (!dyn_run.has_value()) { + return false; + } + + if (auto* rle_run = std::get_if<BitPackedRun>(dyn_run.operator->())) { + decoder_ = {BitPackedDecoder<value_type>(*rle_run)}; + return true; + } + + auto* bit_packed_run = std::get_if<RleRun>(dyn_run.operator->()); + ARROW_DCHECK(bit_packed_run); // Only two possibilities in the variant + decoder_ = {RleDecoder<value_type>(*bit_packed_run)}; + return true; +} + +template <typename T> +auto RleBitPackedDecoder<T>::RunGetBatch(value_type* out, values_count_type batch_size) + -> values_count_type { + return std::visit([&](auto& dec) { return dec.GetBatch(out, batch_size); }, decoder_); +} + +template <typename T> +bool RleBitPackedDecoder<T>::Get(value_type* val) { return GetBatch(val, 1) == 1; } +namespace 
internal { + +/// A ``Parse`` handler that calls a single lambda. +/// +/// This lambda would typically take the input run as ``auto run`` (i.e. the lambda is +/// templated) and deduce other types from it. +template <typename Lambda> +struct LambdaHandler { + Lambda handlder_; + + auto OnBitPackedRun(BitPackedRun run) { return handlder_(std::move(run)); } + + auto OnRleRun(RleRun run) { return handlder_(std::move(run)); } +}; + +template <typename Lambda> +LambdaHandler(Lambda) -> LambdaHandler<Lambda>; + +template <typename value_type, typename Run> +struct decoder_for; + +template <typename value_type> +struct decoder_for<value_type, BitPackedRun> { + using type = BitPackedDecoder<value_type>; +}; + +template <typename value_type> +struct decoder_for<value_type, RleRun> { + using type = RleDecoder<value_type>; +}; + +template <typename value_type, typename Run> +using decoder_for_t = typename decoder_for<value_type, Run>::type; + +} // namespace internal + template <typename T> -inline int RleDecoder::GetBatch(T* values, int batch_size) { - ARROW_DCHECK_GE(bit_width_, 0); - int values_read = 0; - - auto* out = values; - - while (values_read < batch_size) { - int remaining = batch_size - values_read; - - if (repeat_count_ > 0) { // Repeated value case. 
- int repeat_batch = std::min(remaining, repeat_count_); - std::fill(out, out + repeat_batch, static_cast<T>(current_value_)); - - repeat_count_ -= repeat_batch; - values_read += repeat_batch; - out += repeat_batch; - } else if (literal_count_ > 0) { - int literal_batch = std::min(remaining, literal_count_); - int actual_read = bit_reader_.GetBatch(bit_width_, out, literal_batch); - if (actual_read != literal_batch) { - return values_read; - } +auto RleBitPackedDecoder<T>::GetBatch(value_type* out, values_count_type batch_size) + -> values_count_type { + using ControlFlow = RleBitPackedParser::ControlFlow; - literal_count_ -= literal_batch; - values_read += literal_batch; - out += literal_batch; - } else { - if (!NextCounts<T>()) return values_read; + values_count_type values_read = 0; + + // Remaining from a previous call that would have left some unread data from a run. + if (ARROW_PREDICT_FALSE(RunRemaining() > 0)) { + auto const read = RunGetBatch(out, batch_size); + values_read += read; + out += read; + + // Either we fulfilled all the batch to be read or we finished remaining run. 
+ if (ARROW_PREDICT_FALSE(values_read == batch_size)) { + return values_read; } + ARROW_DCHECK(RunRemaining() == 0); } + auto handler = internal::LambdaHandler{ + [&](auto run) { + ARROW_DCHECK_LT(values_read, batch_size); + internal::decoder_for_t<value_type, decltype(run)> decoder(run); + auto const read = decoder.GetBatch(out, batch_size - values_read); + ARROW_DCHECK_LE(read, batch_size - values_read); + values_read += read; + out += read; + + // Stop reading and store remaining decoder + if (ARROW_PREDICT_FALSE(values_read == batch_size || read == 0)) { + decoder_ = std::move(decoder); + return ControlFlow::Break; + } + + return ControlFlow::Continue; + }, + }; + + parser_.Parse(handler); + return values_read; } -template <typename T, typename RunType, typename Converter> -inline int RleDecoder::GetSpaced(Converter converter, int batch_size, int null_count, - const uint8_t* valid_bits, int64_t valid_bits_offset, - T* out) { - if (ARROW_PREDICT_FALSE(null_count == batch_size)) { - converter.FillZero(out, out + batch_size); - return batch_size; +namespace internal { + +/// Utility class to safely handle values and null count without too error-prone +/// verbosity. +class BatchCounter { + public: + using size_type = int32_t; + + [[nodiscard]] static constexpr BatchCounter FromBatchSizeAndNulls( + size_type batch_size, size_type null_count) { + ARROW_DCHECK_LE(null_count, batch_size); + return {batch_size - null_count, null_count}; } - ARROW_DCHECK_GE(bit_width_, 0); - int values_read = 0; - int values_remaining = batch_size - null_count; + constexpr BatchCounter(size_type values_count, size_type null_count) noexcept + : values_count_(values_count), null_count_(null_count) {} - // Assume no bits to start. 
- arrow::internal::BitRunReader bit_reader(valid_bits, valid_bits_offset, - /*length=*/batch_size); - arrow::internal::BitRun valid_run = bit_reader.NextRun(); - while (values_read < batch_size) { - if (ARROW_PREDICT_FALSE(valid_run.length == 0)) { - valid_run = bit_reader.NextRun(); + [[nodiscard]] constexpr size_type ValuesCount() const noexcept { return values_count_; } + + [[nodiscard]] constexpr size_type ValuesRead() const noexcept { return values_read_; } + + [[nodiscard]] constexpr size_type ValuesRemaining() const noexcept { + ARROW_DCHECK_LE(values_read_, values_count_); + return values_count_ - values_read_; + } + + constexpr void AccrueReadValues(size_type to_read) noexcept { + ARROW_DCHECK_LE(to_read, ValuesRemaining()); + values_read_ += to_read; + } + + [[nodiscard]] constexpr size_type NullCount() const noexcept { return null_count_; } + + [[nodiscard]] constexpr size_type NullRead() const noexcept { return null_read_; } + + [[nodiscard]] constexpr size_type NullRemaining() const noexcept { + ARROW_DCHECK_LE(null_read_, null_count_); + return null_count_ - null_read_; + } + + constexpr void AccrueReadNulls(size_type to_read) noexcept { + ARROW_DCHECK_LE(to_read, NullRemaining()); + null_read_ += to_read; + } + + [[nodiscard]] constexpr size_type TotalRemaining() const noexcept { + return ValuesRemaining() + NullRemaining(); + } + + [[nodiscard]] constexpr size_type TotalRead() const noexcept { + return values_read_ + null_read_; + } + + [[nodiscard]] constexpr bool IsFullyNull() const noexcept { + return ValuesRemaining() == 0; + } + + [[nodiscard]] constexpr bool IsDone() const noexcept { return TotalRemaining() == 0; } + + private: + size_type values_count_ = 0; + size_type values_read_ = 0; + size_type null_count_ = 0; + size_type null_read_ = 0; +}; + +// The maximal unsigned size that a variable can fit. 
+template <typename T> +constexpr auto max_size_for_v = + static_cast<std::make_unsigned_t<T>>(std::numeric_limits<T>::max()); + +/// Overload for GetSpaced for a single run in a RleDecoder +template <typename Converter, typename BitRunReader, typename BitRun, + typename values_count_type, typename value_type> +auto RunGetSpaced(Converter& converter, typename Converter::out_type* out, + values_count_type batch_size, values_count_type null_count, + BitRunReader&& validity_reader, BitRun&& validity_run, + RleDecoder<value_type>& decoder) + -> std::pair<values_count_type, values_count_type> { + ARROW_DCHECK_GT(batch_size, 0); + // The equality case is handled in the main loop in GetSpaced + ARROW_DCHECK_LT(null_count, batch_size); + + auto batch = BatchCounter::FromBatchSizeAndNulls(batch_size, null_count); + + values_count_type const values_available = decoder.Remaining(); + ARROW_DCHECK_GT(values_available, 0); + auto values_remaining_run = [&]() { + auto out = values_available - batch.ValuesRead(); + ARROW_DCHECK_GE(out, 0); + return out; + }; + + // Consume as much as possible from the repeated run. + // We only need to count the number of nulls and non-nulls because we can fill in the + // same value for nulls and non-nulls. + // This proves to be a big efficiency win. 
+ while (values_remaining_run() > 0 && !batch.IsDone()) { + ARROW_DCHECK_GE(validity_run.length, 0); + ARROW_DCHECK_LT(validity_run.length, max_size_for_v<values_count_type>); + ARROW_DCHECK_LE(validity_run.length, batch.TotalRemaining()); + auto const& validity_run_size = static_cast<values_count_type>(validity_run.length); + + if (validity_run.set) { + // We may end the current RLE run in the middle of the validity run + auto update_size = std::min(validity_run_size, values_remaining_run()); + batch.AccrueReadValues(update_size); + validity_run.length -= update_size; + } else { + // We can consume all nulls here because it does not matter if we consume on this + // RLE run, or an a next encoded run. The value filled does not matter. + auto update_size = std::min(validity_run_size, batch.NullRemaining()); + batch.AccrueReadNulls(update_size); + validity_run.length -= update_size; + } + + if (ARROW_PREDICT_TRUE(validity_run.length == 0)) { + validity_run = validity_reader.NextRun(); } + } + + value_type const value = decoder.Value(); + if (ARROW_PREDICT_FALSE(!converter.InputIsValid(value))) { + return {0, 0}; + } + converter.WriteRepeated(out, out + batch.TotalRead(), value); + auto const actual_values_read = decoder.Advance(batch.ValuesRead()); + // We always cropped the number of values_read by the remaining values in the run. + // What's more the RLE decoder should not encounter any errors. + ARROW_DCHECK_EQ(actual_values_read, batch.ValuesRead()); - ARROW_DCHECK_GT(batch_size, 0); - ARROW_DCHECK_GT(valid_run.length, 0); + return {batch.ValuesRead(), batch.NullRead()}; +} + +template <typename T, typename... Ts> +[[nodiscard]] constexpr T min(T x, Ts... 
ys) { + ((x = std::min(x, ys)), ...); + return x; +} + +static_assert(min(5) == 5); +static_assert(min(5, 4, -1) == -1); +static_assert(min(5, 41) == 5); + +template <typename Converter, typename BitRunReader, typename BitRun, + typename values_count_type, typename value_type> +auto RunGetSpaced(Converter& converter, typename Converter::out_type* out, + values_count_type batch_size, values_count_type null_count, + BitRunReader&& validity_reader, BitRun&& validity_run, + BitPackedDecoder<value_type>& decoder) + -> std::pair<values_count_type, values_count_type> { + ARROW_DCHECK_GT(batch_size, 0); + // The equality case is handled in the main loop in GetSpaced + ARROW_DCHECK_LT(null_count, batch_size); + + auto batch = BatchCounter::FromBatchSizeAndNulls(batch_size, null_count); + + values_count_type const values_available = decoder.Remaining(); + ARROW_DCHECK_GT(values_available, 0); + auto run_values_remaining = [&]() { + auto out = values_available - batch.ValuesRead(); + ARROW_DCHECK_GE(out, 0); + return out; + }; + + while (run_values_remaining() > 0 && batch.ValuesRemaining() > 0) { Review Comment: For simplification, can we instead instantiate `batch` like this: ```c++ auto batch = BatchCounter::FromBatchSizeAndNulls( std::min(batch_size, values_available + null_count), null_count); ``` -- This is an automated message from the Apache Git Service. To respond to the message, please log on to GitHub and use the URL above to go to the specific comment. To unsubscribe, e-mail: github-unsubscribe@arrow.apache.org For queries about this service, please contact Infrastructure at: users@infra.apache.org