bkietz commented on a change in pull request #9621:
URL: https://github.com/apache/arrow/pull/9621#discussion_r596406429



##########
File path: cpp/src/arrow/compute/kernels/aggregate_basic.cc
##########
@@ -229,6 +604,710 @@ std::unique_ptr<KernelState> AllInit(KernelContext*, const KernelInitArgs& args)
   return ::arrow::internal::make_unique<BooleanAllImpl>();
 }
 
+struct GroupByImpl : public ScalarAggregator {
+  using AddLengthImpl = std::function<void(const std::shared_ptr<ArrayData>&, 
int32_t*)>;
+
+  struct GetAddLengthImpl {
+    static constexpr int32_t null_extra_byte = 1;
+
+    static void AddFixedLength(int32_t fixed_length, int64_t num_repeats,
+                               int32_t* lengths) {
+      for (int64_t i = 0; i < num_repeats; ++i) {
+        lengths[i] += fixed_length + null_extra_byte;
+      }
+    }
+
+    static void AddVarLength(const std::shared_ptr<ArrayData>& data, int32_t* 
lengths) {
+      using offset_type = typename StringType::offset_type;
+      constexpr int32_t length_extra_bytes = sizeof(offset_type);
+      auto offset = data->offset;
+      const auto offsets = data->GetValues<offset_type>(1);
+      if (data->MayHaveNulls()) {
+        const uint8_t* nulls = data->buffers[0]->data();
+
+        for (int64_t i = 0; i < data->length; ++i) {
+          bool is_null = !BitUtil::GetBit(nulls, offset + i);
+          if (is_null) {
+            lengths[i] += null_extra_byte + length_extra_bytes;
+          } else {
+            lengths[i] += null_extra_byte + length_extra_bytes + 
offsets[offset + i + 1] -
+                          offsets[offset + i];
+          }
+        }
+      } else {
+        for (int64_t i = 0; i < data->length; ++i) {
+          lengths[i] += null_extra_byte + length_extra_bytes + offsets[offset 
+ i + 1] -
+                        offsets[offset + i];
+        }
+      }
+    }
+
+    template <typename T>
+    Status Visit(const T& input_type) {
+      int32_t num_bytes = (bit_width(input_type.id()) + 7) / 8;
+      add_length_impl = [num_bytes](const std::shared_ptr<ArrayData>& data,
+                                    int32_t* lengths) {
+        AddFixedLength(num_bytes, data->length, lengths);
+      };
+      return Status::OK();
+    }
+
+    Status Visit(const StringType&) {
+      add_length_impl = [](const std::shared_ptr<ArrayData>& data, int32_t* 
lengths) {
+        AddVarLength(data, lengths);
+      };
+      return Status::OK();
+    }
+
+    Status Visit(const BinaryType&) {
+      add_length_impl = [](const std::shared_ptr<ArrayData>& data, int32_t* 
lengths) {
+        AddVarLength(data, lengths);
+      };
+      return Status::OK();
+    }
+
+    Status Visit(const FixedSizeBinaryType& type) {
+      int32_t num_bytes = type.byte_width();
+      add_length_impl = [num_bytes](const std::shared_ptr<ArrayData>& data,
+                                    int32_t* lengths) {
+        AddFixedLength(num_bytes, data->length, lengths);
+      };
+      return Status::OK();
+    }
+
+    AddLengthImpl add_length_impl;
+  };
+
+  using EncodeNextImpl =
+      std::function<void(const std::shared_ptr<ArrayData>&, uint8_t**)>;
+
+  struct GetEncodeNextImpl {
+    template <int NumBits>
+    static void EncodeSmallFixed(const std::shared_ptr<ArrayData>& data,
+                                 uint8_t** encoded_bytes) {
+      auto raw_input = data->buffers[1]->data();
+      auto offset = data->offset;
+      if (data->MayHaveNulls()) {
+        const uint8_t* nulls = data->buffers[0]->data();
+        for (int64_t i = 0; i < data->length; ++i) {
+          auto& encoded_ptr = encoded_bytes[i];
+          bool is_null = !BitUtil::GetBit(nulls, offset + i);
+          encoded_ptr[0] = is_null ? 1 : 0;
+          encoded_ptr += 1;
+          uint64_t null_multiplier = is_null ? 0 : 1;
+          if (NumBits == 1) {
+            encoded_ptr[0] = static_cast<uint8_t>(
+                null_multiplier * (BitUtil::GetBit(raw_input, offset + i) ? 1 
: 0));
+            encoded_ptr += 1;
+          }
+          if (NumBits == 8) {
+            encoded_ptr[0] =
+                static_cast<uint8_t>(null_multiplier * reinterpret_cast<const 
uint8_t*>(
+                                                           raw_input)[offset + 
i]);
+            encoded_ptr += 1;
+          }
+          if (NumBits == 16) {
+            reinterpret_cast<uint16_t*>(encoded_ptr)[0] =
+                static_cast<uint16_t>(null_multiplier * reinterpret_cast<const 
uint16_t*>(
+                                                            raw_input)[offset 
+ i]);
+            encoded_ptr += 2;
+          }
+          if (NumBits == 32) {
+            reinterpret_cast<uint32_t*>(encoded_ptr)[0] =
+                static_cast<uint32_t>(null_multiplier * reinterpret_cast<const 
uint32_t*>(
+                                                            raw_input)[offset 
+ i]);
+            encoded_ptr += 4;
+          }
+          if (NumBits == 64) {
+            reinterpret_cast<uint64_t*>(encoded_ptr)[0] =
+                static_cast<uint64_t>(null_multiplier * reinterpret_cast<const 
uint64_t*>(
+                                                            raw_input)[offset 
+ i]);
+            encoded_ptr += 8;
+          }
+        }
+      } else {
+        for (int64_t i = 0; i < data->length; ++i) {
+          auto& encoded_ptr = encoded_bytes[i];
+          encoded_ptr[0] = 0;
+          encoded_ptr += 1;
+          if (NumBits == 1) {
+            encoded_ptr[0] = (BitUtil::GetBit(raw_input, offset + i) ? 1 : 0);
+            encoded_ptr += 1;
+          }
+          if (NumBits == 8) {
+            encoded_ptr[0] = reinterpret_cast<const 
uint8_t*>(raw_input)[offset + i];
+            encoded_ptr += 1;
+          }
+          if (NumBits == 16) {
+            reinterpret_cast<uint16_t*>(encoded_ptr)[0] =
+                reinterpret_cast<const uint16_t*>(raw_input)[offset + i];
+            encoded_ptr += 2;
+          }
+          if (NumBits == 32) {
+            reinterpret_cast<uint32_t*>(encoded_ptr)[0] =
+                reinterpret_cast<const uint32_t*>(raw_input)[offset + i];
+            encoded_ptr += 4;
+          }
+          if (NumBits == 64) {
+            reinterpret_cast<uint64_t*>(encoded_ptr)[0] =
+                reinterpret_cast<const uint64_t*>(raw_input)[offset + i];
+            encoded_ptr += 8;
+          }
+        }
+      }
+    }
+
+    static void EncodeBigFixed(int num_bytes, const 
std::shared_ptr<ArrayData>& data,
+                               uint8_t** encoded_bytes) {
+      auto raw_input = data->buffers[1]->data();
+      auto offset = data->offset;
+      if (data->MayHaveNulls()) {
+        const uint8_t* nulls = data->buffers[0]->data();
+        for (int64_t i = 0; i < data->length; ++i) {
+          auto& encoded_ptr = encoded_bytes[i];
+          bool is_null = !BitUtil::GetBit(nulls, offset + i);
+          encoded_ptr[0] = is_null ? 1 : 0;
+          encoded_ptr += 1;
+          if (is_null) {
+            memset(encoded_ptr, 0, num_bytes);
+          } else {
+            memcpy(encoded_ptr, raw_input + num_bytes * (offset + i), 
num_bytes);
+          }
+          encoded_ptr += num_bytes;
+        }
+      } else {
+        for (int64_t i = 0; i < data->length; ++i) {
+          auto& encoded_ptr = encoded_bytes[i];
+          encoded_ptr[0] = 0;
+          encoded_ptr += 1;
+          memcpy(encoded_ptr, raw_input + num_bytes * (offset + i), num_bytes);
+          encoded_ptr += num_bytes;
+        }
+      }
+    }
+
+    static void EncodeVarLength(const std::shared_ptr<ArrayData>& data,
+                                uint8_t** encoded_bytes) {
+      using offset_type = typename StringType::offset_type;
+      auto offset = data->offset;
+      const auto offsets = data->GetValues<offset_type>(1);
+      auto raw_input = data->buffers[2]->data();
+      if (data->MayHaveNulls()) {
+        const uint8_t* nulls = data->buffers[0]->data();
+        for (int64_t i = 0; i < data->length; ++i) {
+          auto& encoded_ptr = encoded_bytes[i];
+          bool is_null = !BitUtil::GetBit(nulls, offset + i);
+          if (is_null) {
+            encoded_ptr[0] = 1;
+            encoded_ptr++;
+            reinterpret_cast<offset_type*>(encoded_ptr)[0] = 0;
+            encoded_ptr += sizeof(offset_type);
+          } else {
+            encoded_ptr[0] = 0;
+            encoded_ptr++;
+            size_t num_bytes = offsets[offset + i + 1] - offsets[offset + i];
+            reinterpret_cast<offset_type*>(encoded_ptr)[0] = num_bytes;
+            encoded_ptr += sizeof(offset_type);
+            memcpy(encoded_ptr, raw_input + offsets[offset + i], num_bytes);
+            encoded_ptr += num_bytes;
+          }
+        }
+      } else {
+        for (int64_t i = 0; i < data->length; ++i) {
+          auto& encoded_ptr = encoded_bytes[i];
+          encoded_ptr[0] = 0;
+          encoded_ptr++;
+          size_t num_bytes = offsets[offset + i + 1] - offsets[offset + i];
+          reinterpret_cast<offset_type*>(encoded_ptr)[0] = num_bytes;
+          encoded_ptr += sizeof(offset_type);
+          memcpy(encoded_ptr, raw_input + offsets[offset + i], num_bytes);
+          encoded_ptr += num_bytes;
+        }
+      }
+    }
+
+    template <typename T>
+    Status Visit(const T& input_type) {
+      int32_t num_bits = bit_width(input_type.id());
+      switch (num_bits) {
+        case 1:
+          encode_next_impl = [](const std::shared_ptr<ArrayData>& data,
+                                uint8_t** encoded_bytes) {
+            EncodeSmallFixed<1>(data, encoded_bytes);
+          };
+          break;
+        case 8:
+          encode_next_impl = [](const std::shared_ptr<ArrayData>& data,
+                                uint8_t** encoded_bytes) {
+            EncodeSmallFixed<8>(data, encoded_bytes);
+          };
+          break;
+        case 16:
+          encode_next_impl = [](const std::shared_ptr<ArrayData>& data,
+                                uint8_t** encoded_bytes) {
+            EncodeSmallFixed<16>(data, encoded_bytes);
+          };
+          break;
+        case 32:
+          encode_next_impl = [](const std::shared_ptr<ArrayData>& data,
+                                uint8_t** encoded_bytes) {
+            EncodeSmallFixed<32>(data, encoded_bytes);
+          };
+          break;
+        case 64:
+          encode_next_impl = [](const std::shared_ptr<ArrayData>& data,
+                                uint8_t** encoded_bytes) {
+            EncodeSmallFixed<64>(data, encoded_bytes);
+          };
+          break;
+      }
+      return Status::OK();
+    }
+
+    Status Visit(const StringType&) {
+      encode_next_impl = [](const std::shared_ptr<ArrayData>& data,
+                            uint8_t** encoded_bytes) {
+        EncodeVarLength(data, encoded_bytes);
+      };
+      return Status::OK();
+    }
+
+    Status Visit(const BinaryType&) {
+      encode_next_impl = [](const std::shared_ptr<ArrayData>& data,
+                            uint8_t** encoded_bytes) {
+        EncodeVarLength(data, encoded_bytes);
+      };
+      return Status::OK();
+    }
+
+    Status Visit(const FixedSizeBinaryType& type) {
+      int32_t num_bytes = type.byte_width();
+      encode_next_impl = [num_bytes](const std::shared_ptr<ArrayData>& data,
+                                     uint8_t** encoded_bytes) {
+        EncodeBigFixed(num_bytes, data, encoded_bytes);
+      };
+      return Status::OK();
+    }
+
+    EncodeNextImpl encode_next_impl;
+  };
+
+  using DecodeNextImpl = std::function<void(KernelContext*, int32_t, uint8_t**,
+                                            std::shared_ptr<ArrayData>*)>;
+
+  struct GetDecodeNextImpl {
+    static Status DecodeNulls(KernelContext* ctx, int32_t length, uint8_t** 
encoded_bytes,
+                              std::shared_ptr<ResizableBuffer>* null_buf,
+                              int32_t* null_count) {
+      // Do we have nulls?
+      *null_count = 0;
+      for (int32_t i = 0; i < length; ++i) {
+        *null_count += encoded_bytes[i][0];
+      }
+      if (*null_count > 0) {
+        ARROW_ASSIGN_OR_RAISE(*null_buf, ctx->AllocateBitmap(length));
+        uint8_t* nulls = (*null_buf)->mutable_data();
+        memset(nulls, 0, (*null_buf)->size());
+        for (int32_t i = 0; i < length; ++i) {
+          if (!encoded_bytes[i][0]) {
+            BitUtil::SetBit(nulls, i);
+          }
+          encoded_bytes[i] += 1;
+        }
+      } else {
+        for (int32_t i = 0; i < length; ++i) {
+          encoded_bytes[i] += 1;
+        }
+      }
+      return Status ::OK();
+    }
+
+    template <int NumBits>
+    static void DecodeSmallFixed(KernelContext* ctx, const Type::type& 
output_type,
+                                 int32_t length, uint8_t** encoded_bytes,
+                                 std::shared_ptr<ArrayData>* out) {
+      std::shared_ptr<ResizableBuffer> null_buf;
+      int32_t null_count;
+      KERNEL_RETURN_IF_ERROR(
+          ctx, DecodeNulls(ctx, length, encoded_bytes, &null_buf, 
&null_count));
+
+      KERNEL_ASSIGN_OR_RAISE(
+          auto key_buf, ctx,
+          ctx->Allocate(NumBits == 1 ? (length + 7) / 8 : (NumBits / 8) * 
length));
+
+      uint8_t* raw_output = key_buf->mutable_data();
+      for (int32_t i = 0; i < length; ++i) {
+        auto& encoded_ptr = encoded_bytes[i];
+        if (NumBits == 1) {
+          BitUtil::SetBitTo(raw_output, i, encoded_ptr[0] != 0);
+          encoded_ptr += 1;
+        }
+        if (NumBits == 8) {
+          raw_output[i] = encoded_ptr[0];
+          encoded_ptr += 1;
+        }
+        if (NumBits == 16) {
+          reinterpret_cast<uint16_t*>(raw_output)[i] =
+              reinterpret_cast<const uint16_t*>(encoded_bytes[i])[0];
+          encoded_ptr += 2;
+        }
+        if (NumBits == 32) {
+          reinterpret_cast<uint32_t*>(raw_output)[i] =
+              reinterpret_cast<const uint32_t*>(encoded_bytes[i])[0];
+          encoded_ptr += 4;
+        }
+        if (NumBits == 64) {
+          reinterpret_cast<uint64_t*>(raw_output)[i] =
+              reinterpret_cast<const uint64_t*>(encoded_bytes[i])[0];
+          encoded_ptr += 8;
+        }
+      }
+
+      DCHECK(is_integer(output_type) || output_type == Type::BOOL);
+      *out = ArrayData::Make(int64(), length, {null_buf, key_buf}, null_count);
+    }
+
+    static void DecodeBigFixed(KernelContext* ctx, int num_bytes, int32_t 
length,
+                               uint8_t** encoded_bytes, 
std::shared_ptr<ArrayData>* out) {
+      std::shared_ptr<ResizableBuffer> null_buf;
+      int32_t null_count;
+      KERNEL_RETURN_IF_ERROR(
+          ctx, DecodeNulls(ctx, length, encoded_bytes, &null_buf, 
&null_count));
+
+      KERNEL_ASSIGN_OR_RAISE(auto key_buf, ctx, ctx->Allocate(num_bytes * 
length));
+      auto raw_output = key_buf->mutable_data();
+      for (int32_t i = 0; i < length; ++i) {
+        memcpy(raw_output + i * num_bytes, encoded_bytes[i], num_bytes);
+        encoded_bytes[i] += num_bytes;
+      }
+
+      *out = ArrayData::Make(fixed_size_binary(num_bytes), length, {null_buf, 
key_buf},
+                             null_count);
+    }
+
+    static void DecodeVarLength(KernelContext* ctx, bool is_string, int32_t 
length,
+                                uint8_t** encoded_bytes,
+                                std::shared_ptr<ArrayData>* out) {
+      std::shared_ptr<ResizableBuffer> null_buf;
+      int32_t null_count;
+      KERNEL_RETURN_IF_ERROR(
+          ctx, DecodeNulls(ctx, length, encoded_bytes, &null_buf, 
&null_count));
+
+      using offset_type = typename StringType::offset_type;
+
+      int32_t length_sum = 0;
+      for (int32_t i = 0; i < length; ++i) {
+        length_sum += reinterpret_cast<offset_type*>(encoded_bytes)[0];
+      }
+
+      KERNEL_ASSIGN_OR_RAISE(auto offset_buf, ctx,
+                             ctx->Allocate(sizeof(offset_type) * (1 + 
length)));
+      KERNEL_ASSIGN_OR_RAISE(auto key_buf, ctx, ctx->Allocate(length_sum));
+
+      auto raw_offsets = offset_buf->mutable_data();
+      auto raw_keys = key_buf->mutable_data();
+      int32_t current_offset = 0;
+      for (int32_t i = 0; i < length; ++i) {
+        offset_type key_length = 
reinterpret_cast<offset_type*>(encoded_bytes[i])[0];
+        reinterpret_cast<offset_type*>(raw_offsets)[i] = current_offset;
+        encoded_bytes[i] += sizeof(offset_type);
+        memcpy(raw_keys + current_offset, encoded_bytes[i], key_length);
+        encoded_bytes[i] += key_length;
+        current_offset += key_length;
+      }
+      reinterpret_cast<offset_type*>(raw_offsets)[length] = current_offset;
+
+      if (is_string) {
+        *out = ArrayData::Make(utf8(), length, {null_buf, offset_buf, key_buf},
+                               null_count, 0);
+      } else {
+        *out = ArrayData::Make(binary(), length, {null_buf, offset_buf, 
key_buf},
+                               null_count, 0);
+      }
+    }
+
+    template <typename T>
+    Status Visit(const T& input_type) {
+      int32_t num_bits = bit_width(input_type.id());
+      auto type_id = input_type.id();
+      switch (num_bits) {
+        case 1:
+          decode_next_impl = [type_id](KernelContext* ctx, int32_t length,
+                                       uint8_t** encoded_bytes,
+                                       std::shared_ptr<ArrayData>* out) {
+            DecodeSmallFixed<1>(ctx, type_id, length, encoded_bytes, out);
+          };
+          break;
+        case 8:
+          decode_next_impl = [type_id](KernelContext* ctx, int32_t length,
+                                       uint8_t** encoded_bytes,
+                                       std::shared_ptr<ArrayData>* out) {
+            DecodeSmallFixed<8>(ctx, type_id, length, encoded_bytes, out);
+          };
+          break;
+        case 16:
+          decode_next_impl = [type_id](KernelContext* ctx, int32_t length,
+                                       uint8_t** encoded_bytes,
+                                       std::shared_ptr<ArrayData>* out) {
+            DecodeSmallFixed<16>(ctx, type_id, length, encoded_bytes, out);
+          };
+          break;
+        case 32:
+          decode_next_impl = [type_id](KernelContext* ctx, int32_t length,
+                                       uint8_t** encoded_bytes,
+                                       std::shared_ptr<ArrayData>* out) {
+            DecodeSmallFixed<32>(ctx, type_id, length, encoded_bytes, out);
+          };
+          break;
+        case 64:
+          decode_next_impl = [type_id](KernelContext* ctx, int32_t length,
+                                       uint8_t** encoded_bytes,
+                                       std::shared_ptr<ArrayData>* out) {
+            DecodeSmallFixed<64>(ctx, type_id, length, encoded_bytes, out);
+          };
+          break;
+      }
+      return Status::OK();
+    }
+
+    Status Visit(const StringType&) {
+      decode_next_impl = [](KernelContext* ctx, int32_t length, uint8_t** 
encoded_bytes,
+                            std::shared_ptr<ArrayData>* out) {
+        DecodeVarLength(ctx, true, length, encoded_bytes, out);
+      };
+      return Status::OK();
+    }
+
+    Status Visit(const BinaryType&) {
+      decode_next_impl = [](KernelContext* ctx, int32_t length, uint8_t** 
encoded_bytes,
+                            std::shared_ptr<ArrayData>* out) {
+        DecodeVarLength(ctx, false, length, encoded_bytes, out);
+      };
+      return Status::OK();
+    }
+
+    Status Visit(const FixedSizeBinaryType& type) {
+      int32_t num_bytes = type.byte_width();
+      decode_next_impl = [num_bytes](KernelContext* ctx, int32_t length,
+                                     uint8_t** encoded_bytes,
+                                     std::shared_ptr<ArrayData>* out) {
+        DecodeBigFixed(ctx, num_bytes, length, encoded_bytes, out);
+      };
+      return Status::OK();
+    }
+
+    DecodeNextImpl decode_next_impl;
+  };
+
+  void Consume(KernelContext* ctx, const ExecBatch& batch) override {
+    ArrayDataVector aggregands, keys;
+
+    size_t i;
+    for (i = 0; i < aggregators.size(); ++i) {
+      aggregands.push_back(batch[i].array());
+    }
+    while (i < static_cast<size_t>(batch.num_values())) {
+      keys.push_back(batch[i++].array());
+    }
+
+    offsets_batch_.clear();
+    offsets_batch_.resize(batch.length + 1);
+    offsets_batch_[0] = 0;
+    memset(offsets_batch_.data(), 0, sizeof(offsets_batch_[0]) * 
offsets_batch_.size());
+    for (size_t i = 0; i < keys.size(); ++i) {
+      add_length_impl[i].add_length_impl(keys[i], offsets_batch_.data());
+    }
+    int32_t total_length = 0;
+    for (int64_t i = 0; i < batch.length; ++i) {
+      auto total_length_before = total_length;
+      total_length += offsets_batch_[i];
+      offsets_batch_[i] = total_length_before;
+    }
+    offsets_batch_[batch.length] = total_length;
+
+    key_bytes_batch_.clear();
+    key_bytes_batch_.resize(total_length);
+    key_buf_ptrs_.clear();
+    key_buf_ptrs_.resize(batch.length);
+    for (int64_t i = 0; i < batch.length; ++i) {
+      key_buf_ptrs_[i] = key_bytes_batch_.data() + offsets_batch_[i];
+    }
+    for (size_t i = 0; i < keys.size(); ++i) {
+      encode_next_impl[i].encode_next_impl(keys[i], key_buf_ptrs_.data());
+    }
+
+    group_ids_batch_.clear();
+    group_ids_batch_.resize(batch.length);
+    for (int64_t i = 0; i < batch.length; ++i) {
+      int32_t key_length = offsets_batch_[i + 1] - offsets_batch_[i];
+      std::string key(

Review comment:
       ARROW-12010 will be used to track the improvement and generalization of the hash table.




----------------------------------------------------------------
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

For queries about this service, please contact Infrastructure at:
us...@infra.apache.org


Reply via email to