felipecrv commented on code in PR #34195:
URL: https://github.com/apache/arrow/pull/34195#discussion_r1130216822


##########
cpp/src/arrow/compute/kernels/vector_run_end_encode.cc:
##########
@@ -0,0 +1,672 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements.  See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership.  The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License.  You may obtain a copy of the License at
+//
+//   http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied.  See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+#include <utility>
+
+#include "arrow/compute/api_vector.h"
+#include "arrow/compute/kernels/common_internal.h"
+#include "arrow/util/checked_cast.h"
+#include "arrow/util/ree_util.h"
+
+namespace arrow {
+namespace compute {
+namespace internal {
+
+/// Reads one value (and its validity) from an input array with C-compatible
+/// values. The primary template is empty; only the specializations provide
+/// ReadValue().
+template <typename ArrowType, bool has_validity_buffer, typename Enable = void>
+struct ReadValueImpl {};
+
+// Numeric and primitive C-compatible types
+template <typename ArrowType, bool has_validity_buffer>
+struct ReadValueImpl<ArrowType, has_validity_buffer, enable_if_has_c_type<ArrowType>> {
+  using CType = typename ArrowType::c_type;
+
+  /// \brief Read the slot at `read_offset` (an absolute index into the buffers).
+  ///
+  /// \param[in] input_validity validity bitmap; only read when has_validity_buffer
+  /// \param[in] input_values values buffer, interpreted as an array of CType
+  /// \param[out] out receives the slot's value, or a value-initialized CType
+  ///             when the slot is null
+  /// \param[in] read_offset absolute slot index
+  /// \return whether the slot is valid (non-null)
+  [[nodiscard]] bool ReadValue(const uint8_t* input_validity, const void* input_values,
+                               CType* out, int64_t read_offset) const {
+    bool valid = true;
+    if constexpr (has_validity_buffer) {
+      valid = bit_util::GetBit(input_validity, read_offset);
+    }
+    // Always define *out: callers compare and copy the value even for null
+    // slots, and leaving it unset made them read an indeterminate value (UB).
+    // That could split a run of nulls or copy garbage into the output.
+    *out = valid ? reinterpret_cast<const CType*>(input_values)[read_offset]
+                 : CType{};
+    return valid;
+  }
+};
+
+// Boolean w/ validity_bitmap: a null slot reports `false` through *out and the
+// values bitmap is not consulted for it.
+template <>
+bool ReadValueImpl<BooleanType, true>::ReadValue(const uint8_t* input_validity,
+                                                 const void* input_values, CType* out,
+                                                 int64_t read_offset) const {
+  if (!bit_util::GetBit(input_validity, read_offset)) {
+    *out = false;
+    return false;
+  }
+  const auto* value_bits = reinterpret_cast<const uint8_t*>(input_values);
+  *out = bit_util::GetBit(value_bits, read_offset);
+  return true;
+}
+
+// Boolean w/o validity_bitmap: every slot is valid, so the validity pointer is
+// ignored and only the values bitmap is read.
+template <>
+bool ReadValueImpl<BooleanType, false>::ReadValue(const uint8_t* /*input_validity*/,
+                                                  const void* input_values, CType* out,
+                                                  int64_t read_offset) const {
+  const auto* value_bits = reinterpret_cast<const uint8_t*>(input_values);
+  *out = bit_util::GetBit(value_bits, read_offset);
+  return true;
+}
+
+/// Writes values (and, when present, validity bits) into an output array with
+/// C-compatible values. The primary template is empty; only the
+/// specializations provide WriteValue() and WriteRun().
+template <typename ArrowType, bool has_validity_buffer, typename Enable = void>
+struct WriteValueImpl {};
+
+// Numeric and primitive C-compatible types
+template <typename ArrowType, bool has_validity_buffer>
+struct WriteValueImpl<ArrowType, has_validity_buffer, enable_if_has_c_type<ArrowType>> {
+  using CType = typename ArrowType::c_type;
+
+  /// \brief Store `value` at slot `write_offset`, plus its validity bit when the
+  /// output has a validity bitmap.
+  void WriteValue(uint8_t* output_validity, void* output_values, int64_t write_offset,
+                  bool valid, CType value) const {
+    if constexpr (has_validity_buffer) {
+      bit_util::SetBitsTo(output_validity, write_offset, 1, valid);
+    }
+    auto* typed_values = static_cast<CType*>(output_values);
+    typed_values[write_offset] = value;
+  }
+
+  /// \brief Store `value` into `run_length` consecutive slots starting at
+  /// `write_offset`, plus the matching validity bits when present.
+  void WriteRun(uint8_t* output_validity, void* output_values, int64_t write_offset,
+                int64_t run_length, bool valid, CType value) const {
+    if constexpr (has_validity_buffer) {
+      bit_util::SetBitsTo(output_validity, write_offset, run_length, valid);
+    }
+    CType* const run_begin = static_cast<CType*>(output_values) + write_offset;
+    std::fill_n(run_begin, run_length, value);
+  }
+};
+
+// Boolean w/ validity_bitmap
+//
+// The value bit is written even for a null slot (as `false`) instead of being
+// skipped, so no bit of the output values bitmap is left unwritten.
+// NOTE(review): this matters if the output buffer is not zero-initialized.
+template <>
+void WriteValueImpl<BooleanType, true>::WriteValue(uint8_t* output_validity,
+                                                   void* output_values,
+                                                   int64_t write_offset, bool valid,
+                                                   CType value) const {
+  bit_util::SetBitTo(output_validity, write_offset, valid);
+  bit_util::SetBitTo(reinterpret_cast<uint8_t*>(output_values), write_offset,
+                     valid && value);
+}
+
+// Boolean w/ validity_bitmap: write a whole run. The value bits of a null run
+// are written as `false` instead of being skipped, so no bit of the output
+// values bitmap in the run's range is left unwritten.
+template <>
+void WriteValueImpl<BooleanType, true>::WriteRun(uint8_t* output_validity,
+                                                 void* output_values,
+                                                 int64_t write_offset, int64_t run_length,
+                                                 bool valid, CType value) const {
+  bit_util::SetBitsTo(output_validity, write_offset, run_length, valid);
+  bit_util::SetBitsTo(reinterpret_cast<uint8_t*>(output_values), write_offset,
+                      run_length, valid && value);
+}
+
+// Boolean w/o validity_bitmap: there is no validity bitmap to maintain, so the
+// validity pointer and the `valid` flag are ignored.
+template <>
+void WriteValueImpl<BooleanType, false>::WriteValue(uint8_t* /*output_validity*/,
+                                                    void* output_values,
+                                                    int64_t write_offset, bool /*valid*/,
+                                                    CType value) const {
+  auto* value_bits = reinterpret_cast<uint8_t*>(output_values);
+  bit_util::SetBitTo(value_bits, write_offset, value);
+}
+
+// Boolean w/o validity_bitmap: set `run_length` value bits starting at
+// `write_offset`; the validity pointer and the `valid` flag are ignored.
+template <>
+void WriteValueImpl<BooleanType, false>::WriteRun(uint8_t* /*output_validity*/,
+                                                  void* output_values,
+                                                  int64_t write_offset,
+                                                  int64_t run_length, bool /*valid*/,
+                                                  CType value) const {
+  auto* value_bits = reinterpret_cast<uint8_t*>(output_values);
+  bit_util::SetBitsTo(value_bits, write_offset, run_length, value);
+}
+
+/// Kernel state that stores a run-end DataType (`run_end_type`).
+// NOTE(review): "Enconding" looks like a typo of "Encoding"; renaming it would
+// require updating every reference outside this excerpt.
+struct RunEndEncondingState : public KernelState {
+  std::shared_ptr<DataType> run_end_type;
+
+  explicit RunEndEncondingState(std::shared_ptr<DataType> run_end_type)
+      : run_end_type(std::move(run_end_type)) {}
+
+  ~RunEndEncondingState() override = default;
+};
+
+/// Scans `input_length` slots starting at `input_offset`. It can either count
+/// the runs (CountNumberOfRuns) or write the run-end encoded output
+/// (WriteEncodedRuns).
+///
+/// A run is a maximal sequence of slots with equal validity and, for valid
+/// slots, equal values; consecutive null slots always form a single run.
+template <typename RunEndType, typename ValueType, bool has_validity_buffer>
+class RunEndEncodingLoop {
+ public:
+  using RunEndCType = typename RunEndType::c_type;
+  using CType = typename ValueType::c_type;
+
+ private:
+  const int64_t input_length_;
+  const int64_t input_offset_;
+
+  const uint8_t* input_validity_;
+  const void* input_values_;
+
+  // Needed only by WriteEncodedRuns()
+  uint8_t* output_validity_;
+  void* output_values_;
+  RunEndCType* output_run_ends_;
+
+ public:
+  /// \param input_length number of input slots to scan; must be > 0
+  /// \param input_offset absolute index of the first input slot
+  /// \param input_validity input validity bitmap (read only if has_validity_buffer)
+  /// \param input_values input values buffer
+  /// \param output_validity, output_values, output_run_ends output buffers;
+  ///        only required by WriteEncodedRuns()
+  RunEndEncodingLoop(int64_t input_length, int64_t input_offset,
+                     const uint8_t* input_validity, const void* input_values,
+                     uint8_t* output_validity = NULLPTR, void* output_values = NULLPTR,
+                     RunEndCType* output_run_ends = NULLPTR)
+      : input_length_(input_length),
+        input_offset_(input_offset),
+        input_validity_(input_validity),
+        input_values_(input_values),
+        output_validity_(output_validity),
+        output_values_(output_values),
+        output_run_ends_(output_run_ends) {
+    DCHECK_GT(input_length, 0);
+  }
+
+ private:
+  [[nodiscard]] inline bool ReadValue(CType* out, int64_t read_offset) const {
+    return ReadValueImpl<ValueType, has_validity_buffer>{}.ReadValue(
+        input_validity_, input_values_, out, read_offset);
+  }
+
+  inline void WriteValue(int64_t write_offset, bool valid, CType value) {
+    WriteValueImpl<ValueType, has_validity_buffer>{}.WriteValue(
+        output_validity_, output_values_, write_offset, valid, value);
+  }
+
+  /// \brief Whether slot (`valid`, `value`) starts a new run after the current
+  /// run (`current_run_valid`, `current_run`).
+  ///
+  /// Values are compared only when both slots are valid. The value of a null
+  /// slot carries no meaning and may not have been set by ReadValue, so
+  /// comparing it would read an indeterminate value and could split a run of
+  /// nulls.
+  static inline bool OpensNewRun(bool current_run_valid, CType current_run, bool valid,
+                                 CType value) {
+    if (valid != current_run_valid) {
+      return true;
+    }
+    return valid && value != current_run;
+  }
+
+ public:
+  /// \brief Give a pass over the input data and count the number of runs
+  ///
+  /// \return a pair with the number of non-null run values and total number of runs
+  ARROW_NOINLINE std::pair<int64_t, int64_t> CountNumberOfRuns() const {
+    int64_t read_offset = input_offset_;
+    // Value-initialized so the value is defined even when the slot is null.
+    CType current_run{};
+    bool current_run_valid = ReadValue(&current_run, read_offset);
+    read_offset += 1;
+    int64_t num_valid_runs = current_run_valid ? 1 : 0;
+    int64_t num_output_runs = 1;
+    for (; read_offset < input_offset_ + input_length_; read_offset += 1) {
+      CType value{};
+      const bool valid = ReadValue(&value, read_offset);
+
+      if (OpensNewRun(current_run_valid, current_run, valid, value)) {
+        // Open the new run
+        current_run = value;
+        current_run_valid = valid;
+        // Count the new run
+        num_output_runs += 1;
+        num_valid_runs += valid ? 1 : 0;
+      }
+    }
+    return std::make_pair(num_valid_runs, num_output_runs);
+  }
+
+  /// \brief Give a second pass over the input data and write one value (and
+  /// validity bit) plus one run end per run into the output buffers.
+  ///
+  /// The run-break rule is the same as in CountNumberOfRuns(), so the two
+  /// passes agree on the number of runs.
+  ///
+  /// \return the number of runs written
+  ARROW_NOINLINE int64_t WriteEncodedRuns() {
+    DCHECK(output_values_);
+    DCHECK(output_run_ends_);
+    int64_t read_offset = input_offset_;
+    int64_t write_offset = 0;
+    // Value-initialized so a null run never copies an indeterminate value.
+    CType current_run{};
+    bool current_run_valid = ReadValue(&current_run, read_offset);
+    read_offset += 1;
+    for (; read_offset < input_offset_ + input_length_; read_offset += 1) {
+      CType value{};
+      const bool valid = ReadValue(&value, read_offset);
+
+      if (OpensNewRun(current_run_valid, current_run, valid, value)) {
+        // Close the current run first by writing it out
+        WriteValue(write_offset, current_run_valid, current_run);
+        const int64_t run_end = read_offset - input_offset_;
+        output_run_ends_[write_offset] = static_cast<RunEndCType>(run_end);
+        write_offset += 1;
+        // Open the new run
+        current_run_valid = valid;
+        current_run = value;
+      }
+    }
+    WriteValue(write_offset, current_run_valid, current_run);
+    DCHECK_EQ(input_length_, read_offset - input_offset_);
+    output_run_ends_[write_offset] = static_cast<RunEndCType>(input_length_);
+    return write_offset + 1;
+  }
+};
+
+/// \brief Fail if `input_length` exceeds the largest value representable by
+/// RunEndType's C type.
+///
+/// \return Status::OK() when it fits, Status::Invalid otherwise
+template <typename RunEndType>
+Status ValidateRunEndType(int64_t input_length) {
+  constexpr int64_t kRunEndMax =
+      std::numeric_limits<typename RunEndType::c_type>::max();
+  if (input_length <= kRunEndMax) {
+    return Status::OK();
+  }
+  return Status::Invalid(
+      "Cannot run-end encode Arrays with more elements than the "
+      "run end type can hold: ",
+      kRunEndMax);
+}
+
+/// Run-end encodes a primitive input array.
+///
+/// Exec() makes two passes over the input: the first counts the runs so that
+/// the output buffers can be sized exactly, the second writes the runs.
+template <typename RunEndType, typename ValueType, bool has_validity_buffer>
+class RunEndEncodeImpl {
+ private:
+  KernelContext* ctx_;
+  const ArraySpan& input_array_;
+  ExecResult* output_;
+
+ public:
+  using RunEndCType = typename RunEndType::c_type;
+  using CType = typename ValueType::c_type;
+
+  RunEndEncodeImpl(KernelContext* ctx, const ArraySpan& input_array, ExecResult* out)
+      : ctx_{ctx}, input_array_{input_array}, output_{out} {}
+
+  /// \brief Encode `input_array_` and store the resulting ArrayData in
+  /// `output_->value`.
+  ///
+  /// \return an error from ValidateRunEndType() or from buffer allocation,
+  ///         otherwise Status::OK()
+  Status Exec() {
+    const int64_t input_length = input_array_.length;
+    const int64_t input_offset = input_array_.offset;
+    const auto* input_validity = input_array_.buffers[0].data;
+    const auto* input_values = input_array_.buffers[1].data;
+
+    // First pass: count the number of runs
+    int64_t num_valid_runs = 0;
+    int64_t num_output_runs = 0;
+    if (input_length > 0) {
+      RETURN_NOT_OK(ValidateRunEndType<RunEndType>(input_length));
+
+      RunEndEncodingLoop<RunEndType, ValueType, has_validity_buffer> counting_loop(
+          input_length, input_offset, input_validity, input_values);
+      std::tie(num_valid_runs, num_output_runs) = counting_loop.CountNumberOfRuns();
+    }
+
+    // Allocate the output array data
+    std::shared_ptr<ArrayData> output_array_data;
+    int64_t validity_buffer_size = 0;  // in bytes
+    int64_t values_buffer_size = 0;    // in bytes
+    {
+      // Run ends are stored as RunEndCType values, so size the buffer in bytes.
+      // The previous `num_output_runs * bit_width()` used a bit count as a
+      // byte count and over-allocated by a factor of 8.
+      const int64_t run_ends_buffer_size =
+          num_output_runs * static_cast<int64_t>(sizeof(RunEndCType));
+      ARROW_ASSIGN_OR_RAISE(auto run_ends_buffer,
+                            AllocateBuffer(run_ends_buffer_size, ctx_->memory_pool()));
+      std::shared_ptr<Buffer> validity_buffer = NULLPTR;
+      if constexpr (has_validity_buffer) {
+        validity_buffer_size = bit_util::BytesForBits(num_output_runs);
+        ARROW_ASSIGN_OR_RAISE(validity_buffer,
+                              AllocateBuffer(validity_buffer_size, ctx_->memory_pool()));
+      }
+      values_buffer_size =
+          bit_util::BytesForBits(num_output_runs * ValueType().bit_width());
+      ARROW_ASSIGN_OR_RAISE(auto values_buffer,
+                            AllocateBuffer(values_buffer_size, ctx_->memory_pool()));
+
+      auto ree_type = std::make_shared<RunEndEncodedType>(
+          std::make_shared<RunEndType>(), input_array_.type->GetSharedPtr());
+      auto run_ends_data =
+          ArrayData::Make(ree_type->run_end_type(), num_output_runs,
+                          {NULLPTR, std::move(run_ends_buffer)}, /*null_count=*/0);
+      auto values_data =
+          ArrayData::Make(ree_type->value_type(), num_output_runs,
+                          {std::move(validity_buffer), std::move(values_buffer)},
+                          /*null_count=*/num_output_runs - num_valid_runs);
+
+      output_array_data =
+          ArrayData::Make(std::move(ree_type), input_length, {NULLPTR},
+                          {std::move(run_ends_data), std::move(values_data)},
+                          /*null_count=*/0);
+    }
+
+    if (input_length > 0) {
+      // Initialize the output pointers
+      auto* output_run_ends =
+          output_array_data->child_data[0]->template GetMutableValues<RunEndCType>(1);
+      auto* output_validity =
+          output_array_data->child_data[1]->template GetMutableValues<uint8_t>(0);
+      auto* output_values =
+          output_array_data->child_data[1]->template GetMutableValues<uint8_t>(1);
+
+      // Clear the last byte of each bitmap buffer so its padding bits are zeroed.
+      // The values buffer is a bitmap too when ValueType is boolean; for wider
+      // types this byte is fully overwritten by the last run anyway.
+      if constexpr (has_validity_buffer) {
+        output_validity[validity_buffer_size - 1] = 0;
+      }
+      output_values[values_buffer_size - 1] = 0;
+
+      // Second pass: write the runs
+      RunEndEncodingLoop<RunEndType, ValueType, has_validity_buffer> writing_loop(
+          input_length, input_offset, input_validity, input_values, output_validity,
+          output_values, output_run_ends);
+      [[maybe_unused]] int64_t num_written_runs = writing_loop.WriteEncodedRuns();
+      DCHECK_EQ(num_written_runs, num_output_runs);
+    }
+
+    output_->value = std::move(output_array_data);
+    return Status::OK();
+  }
+};
+
+template <typename RunEndType>
+Status RunEndEncodeNullArray(KernelContext* ctx, const ArraySpan& input_array,
+                             ExecResult* output) {
+  using RunEndCType = typename RunEndType::c_type;
+
+  const int64_t input_length = input_array.length;
+  auto input_array_type = input_array.type->GetSharedPtr();
+  DCHECK(input_array_type->id() == Type::NA);
+
+  int64_t num_output_runs = 0;
+  if (input_length > 0) {
+    // Abort if run-end type cannot hold the input length
+    RETURN_NOT_OK(ValidateRunEndType<RunEndType>(input_array.length));
+    num_output_runs = 1;
+  }
+
+  // Allocate the output array data
+  std::shared_ptr<ArrayData> output_array_data;
+  {
+    ARROW_ASSIGN_OR_RAISE(
+        auto run_ends_buffer,
+        AllocateBuffer(num_output_runs * RunEndType().bit_width(), 
ctx->memory_pool()));
+
+    auto ree_type = 
std::make_shared<RunEndEncodedType>(std::make_shared<RunEndType>(),
+                                                        input_array_type);
+    auto run_ends_data = ArrayData::Make(std::make_shared<RunEndType>(), 
num_output_runs,
+                                         {NULLPTR, std::move(run_ends_buffer)},
+                                         /*null_count=*/0);
+    auto values_data = ArrayData::Make(input_array_type, num_output_runs, 
{NULLPTR},
+                                       /*null_count=*/num_output_runs);
+
+    output_array_data =
+        ArrayData::Make(std::move(ree_type), input_length, {NULLPTR},
+                        {std::move(run_ends_data), std::move(values_data)},
+                        /*null_count=*/0);
+  }

Review Comment:
   @zeroshade early-returning for the 0 case means repeating this block. Don't
you think it's more elegant for this code to handle any length, so that it
doesn't have to be duplicated?



-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

To unsubscribe, e-mail: [email protected]

For queries about this service, please contact Infrastructure at:
[email protected]

Reply via email to