icexelloss commented on code in PR #34912:
URL: https://github.com/apache/arrow/pull/34912#discussion_r1170588049


##########
cpp/src/arrow/compute/kernels/hash_aggregate.cc:
##########
@@ -1695,6 +1692,389 @@ struct GroupedMinMaxFactory {
   InputType argument_type;
 };
 
+// ----------------------------------------------------------------------
+// FirstLast implementation
+
+// Filler value written into the first/last buffers for groups that have not yet
+// received a non-null value. Floating-point types use +infinity because
+// std::numeric_limits<T>::min() is the smallest positive normal value for them,
+// not the most negative one; every other type uses numeric_limits<T>::min().
+// (In GroupedFirstLastImpl these slots are masked by the has_values_ bitmap.)
+template <typename CType>
+struct NullSentinel {
+  static constexpr CType value() {
+    return (std::is_same<CType, float>::value || std::is_same<CType, double>::value)
+               ? std::numeric_limits<CType>::infinity()
+               : std::numeric_limits<CType>::min();
+  }
+};
+
+// Hash aggregate computing, for every group, the first and the last non-null value
+// in the order rows are consumed. The values are kept in two fixed-width buffers;
+// `has_values_` marks groups that have seen at least one non-null value and is
+// reused as the validity bitmap of both output columns. `has_nulls_` records whether
+// a group saw a null; it is maintained but not yet used (Finalize rejects
+// skip_nulls=false).
+template <typename Type, typename Enable = void>
+struct GroupedFirstLastImpl final : public GroupedAggregator {
+  using CType = typename TypeTraits<Type>::CType;
+  using GetSet = GroupedValueTraits<Type>;
+  using ArrType =
+      typename std::conditional<is_boolean_type<Type>::value, uint8_t, CType>::type;
+
+  Status Init(ExecContext* ctx, const KernelInitArgs& args) override {
+    options_ = *checked_cast<const ScalarAggregateOptions*>(args.options);
+    // NOTE(review): type_ is not assigned here; presumably the kernel init function
+    // sets it (as is done for GroupedMinMaxImpl) — confirm.
+    firsts_ = TypedBufferBuilder<CType>(ctx->memory_pool());
+    lasts_ = TypedBufferBuilder<CType>(ctx->memory_pool());
+    has_values_ = TypedBufferBuilder<bool>(ctx->memory_pool());
+    has_nulls_ = TypedBufferBuilder<bool>(ctx->memory_pool());
+    return Status::OK();
+  }
+
+  // Grows the per-group state; new groups start with no value and no null seen.
+  Status Resize(int64_t new_num_groups) override {
+    auto added_groups = new_num_groups - num_groups_;
+    num_groups_ = new_num_groups;
+    RETURN_NOT_OK(firsts_.Append(added_groups, NullSentinel<CType>::value()));
+    RETURN_NOT_OK(lasts_.Append(added_groups, NullSentinel<CType>::value()));
+    RETURN_NOT_OK(has_values_.Append(added_groups, false));
+    RETURN_NOT_OK(has_nulls_.Append(added_groups, false));
+    return Status::OK();
+  }
+
+  // A group's "first" is written only by its first non-null value; its "last" is
+  // overwritten by every non-null value. Nulls only flag the group in has_nulls_.
+  Status Consume(const ExecSpan& batch) override {
+    auto raw_firsts = firsts_.mutable_data();
+    auto raw_lasts = lasts_.mutable_data();
+    auto raw_has_values = has_values_.mutable_data();
+    auto raw_has_nulls = has_nulls_.mutable_data();
+
+    VisitGroupedValues<Type>(
+        batch,
+        [&](uint32_t g, CType val) {
+          if (!bit_util::GetBit(raw_has_values, g)) {
+            GetSet::Set(raw_firsts, g, val);
+            bit_util::SetBit(raw_has_values, g);
+          }
+          GetSet::Set(raw_lasts, g, val);
+        },
+        [&](uint32_t g) { bit_util::SetBit(raw_has_nulls, g); });
+    return Status::OK();
+  }
+
+  Status Merge(GroupedAggregator&& raw_other,
+               const ArrayData& group_id_mapping) override {
+    // The merge is asymmetric: "first" from this state takes precedence over "first"
+    // from the other state, while "last" from the other state takes precedence over
+    // "last" from this state. This keeps "first" and "last" correct for the entire
+    // segment when used with segmented aggregation; it is only correct when
+    // `raw_other` holds rows that come after the rows already merged into this state.
+    auto other = checked_cast<GroupedFirstLastImpl*>(&raw_other);
+
+    auto raw_firsts = firsts_.mutable_data();
+    auto raw_lasts = lasts_.mutable_data();
+    auto raw_has_values = has_values_.mutable_data();
+    auto raw_has_nulls = has_nulls_.mutable_data();
+
+    auto other_raw_firsts = other->firsts_.mutable_data();
+    auto other_raw_lasts = other->lasts_.mutable_data();
+    auto other_raw_has_values = other->has_values_.mutable_data();
+    auto other_raw_has_nulls = other->has_nulls_.mutable_data();
+
+    // Entry `other_g` of group_id_mapping is the group in this state that the
+    // other state's group `other_g` is merged into.
+    auto g = group_id_mapping.GetValues<uint32_t>(1);
+
+    for (uint32_t other_g = 0; static_cast<int64_t>(other_g) < group_id_mapping.length;
+         ++other_g, ++g) {
+      if (bit_util::GetBit(other_raw_has_values, other_g)) {
+        // Keep our own "first" if we already have one; otherwise adopt the other's.
+        if (!bit_util::GetBit(raw_has_values, *g)) {
+          GetSet::Set(raw_firsts, *g, GetSet::Get(other_raw_firsts, other_g));
+        }
+        // The other state's "last" always wins.
+        GetSet::Set(raw_lasts, *g, GetSet::Get(other_raw_lasts, other_g));
+        bit_util::SetBit(raw_has_values, *g);
+      }
+      if (bit_util::GetBit(other_raw_has_nulls, other_g)) {
+        bit_util::SetBit(raw_has_nulls, *g);
+      }
+    }
+    return Status::OK();
+  }
+
+  // Produces a struct<first, last> array with one row per group; groups that never
+  // saw a non-null value are null in both children.
+  Result<Datum> Finalize() override {
+    // Reject unsupported options before consuming any of the builders.
+    if (!options_.skip_nulls) {
+      return Status::NotImplemented("Don't support first/last with skip nulls = False");
+    }
+
+    // has_values_ is shared as the validity bitmap of both children.
+    ARROW_ASSIGN_OR_RAISE(auto null_bitmap, has_values_.Finish());
+    auto firsts = ArrayData::Make(type_, num_groups_, {null_bitmap, nullptr});
+    auto lasts = ArrayData::Make(type_, num_groups_, {std::move(null_bitmap), nullptr});
+    ARROW_ASSIGN_OR_RAISE(firsts->buffers[1], firsts_.Finish());
+    ARROW_ASSIGN_OR_RAISE(lasts->buffers[1], lasts_.Finish());
+
+    return ArrayData::Make(out_type(), num_groups_, {nullptr},
+                           {std::move(firsts), std::move(lasts)});
+  }
+
+  // struct<first: type_, last: type_>
+  std::shared_ptr<DataType> out_type() const override {
+    return struct_({field("first", type_), field("last", type_)});
+  }
+
+  // Resize() computes the number of added groups from this, so it must start at 0.
+  int64_t num_groups_ = 0;
+  TypedBufferBuilder<CType> firsts_, lasts_;
+  TypedBufferBuilder<bool> has_values_, has_nulls_;
+  std::shared_ptr<DataType> type_;
+  ScalarAggregateOptions options_;
+};
+
+template <typename Type>
+struct GroupedFirstLastImpl<Type,

Review Comment:
   This class mostly follows GroupedMinMaxImpl, with slight modifications to the Consume and Merge logic.



-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

To unsubscribe, e-mail: [email protected]

For queries about this service, please contact Infrastructure at:
[email protected]

Reply via email to