This is an automated email from the ASF dual-hosted git repository.
panxiaolei pushed a commit to branch ckb2
in repository https://gitbox.apache.org/repos/asf/doris.git
The following commit(s) were added to refs/heads/ckb2 by this push:
new df47cb8c323 [opt](agg) Optimize the execution of GROUP BY count(*) (#61310)
df47cb8c323 is described below
commit df47cb8c323e96ddf57210b4daf741fa8c05bc43
Author: Mryange <[email protected]>
AuthorDate: Fri Mar 13 16:18:12 2026 +0800
[opt](agg) Optimize the execution of GROUP BY count(*) (#61310)
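### What problem does this PR solve?
For a GROUP BY query whose only aggregate is count(*) / count(not_null_column), store a uint64 counter directly in the hash table instead of the full aggregate state, saving memory and computation overhead. Example timings: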
```
MySQL [hits]> SELECT ClientIP,sum(ClientIP) AS c
-> FROM hits
-> GROUP BY ClientIP
-> ORDER BY c DESC
-> LIMIT 10;
10 rows in set (0.374 sec)
MySQL [hits]> SELECT ClientIP, COUNT(*) AS c
-> FROM hits
-> GROUP BY ClientIP
-> ORDER BY c DESC
-> LIMIT 10;
10 rows in set (0.312 sec)
```
None
### Check List (For Author)
- Test <!-- At least one of them must be included. -->
- [ ] Regression test
- [ ] Unit Test
- [ ] Manual test (add detailed scripts or steps below)
- [ ] No need to test or manual test. Explain why:
- [ ] This is a refactor/code format and no logic has been changed.
- [ ] Previous test can cover this change.
- [ ] No code files have been changed.
- [ ] Other reason <!-- Add your reason? -->
- Behavior changed:
- [ ] No.
- [ ] Yes. <!-- Explain the behavior change -->
- Does this need documentation?
- [ ] No.
- [ ] Yes. <!-- Add document PR link here, e.g. https://github.com/apache/doris-website/pull/1214 -->
### Check List (For Reviewer who merge this PR)
- [ ] Confirm the release note
- [ ] Confirm test cases
- [ ] Confirm document
- [ ] Add branch pick label <!-- Add branch pick label that this PR should merge into -->
---
be/src/exec/common/hash_table/hash_map_context.h | 58 +-
be/src/exec/common/hash_table/ph_hash_map.h | 6 +
be/src/exec/common/hash_table/string_hash_map.h | 25 +-
be/src/exec/common/hash_table/string_hash_table.h | 14 +-
be/src/exec/operator/aggregation_sink_operator.cpp | 125 +++-
be/src/exec/operator/aggregation_sink_operator.h | 4 +
.../exec/operator/aggregation_source_operator.cpp | 117 ++++
be/src/exec/operator/set_probe_sink_operator.cpp | 39 ++
be/src/exec/operator/set_source_operator.cpp | 16 +-
be/src/exec/operator/set_source_operator.h | 1 +
.../operator/streaming_aggregation_operator.cpp | 202 +++++-
.../exec/operator/streaming_aggregation_operator.h | 7 +
be/src/exec/pipeline/dependency.h | 6 +
be/src/exprs/aggregate/aggregate_function.h | 2 +
be/src/exprs/aggregate/aggregate_function_count.h | 1 +
be/test/exec/hash_map/hash_table_method_test.cpp | 712 ++++++++++++++++++++-
16 files changed, 1233 insertions(+), 102 deletions(-)
diff --git a/be/src/exec/common/hash_table/hash_map_context.h
b/be/src/exec/common/hash_table/hash_map_context.h
index b036cb95590..8b7f0d6b48a 100644
--- a/be/src/exec/common/hash_table/hash_map_context.h
+++ b/be/src/exec/common/hash_table/hash_map_context.h
@@ -1126,7 +1126,7 @@ struct MethodKeysFixed : public MethodBase<TData> {
};
template <typename Base>
-struct DataWithNullKeyImpl : public Base {
+struct DataWithNullKey : public Base {
bool& has_null_key_data() { return has_null_key; }
bool has_null_key_data() const { return has_null_key; }
template <typename MappedType>
@@ -1151,62 +1151,6 @@ protected:
Base::Value null_key_data;
};
-template <typename Base>
-struct DataWithNullKey : public DataWithNullKeyImpl<Base> {};
-
-template <IteratoredMap Base>
-struct DataWithNullKey<Base> : public DataWithNullKeyImpl<Base> {
- using DataWithNullKeyImpl<Base>::null_key_data;
- using DataWithNullKeyImpl<Base>::has_null_key;
-
- struct Iterator {
- typename Base::iterator base_iterator = {};
- bool current_null = false;
- Base::Value* null_key_data = nullptr;
-
- Iterator() = default;
- Iterator(typename Base::iterator it, bool null, Base::Value* null_key)
- : base_iterator(it), current_null(null),
null_key_data(null_key) {}
- bool operator==(const Iterator& rhs) const {
- return current_null == rhs.current_null && base_iterator ==
rhs.base_iterator;
- }
-
- bool operator!=(const Iterator& rhs) const { return !(*this == rhs); }
-
- Iterator& operator++() {
- if (current_null) {
- current_null = false;
- } else {
- ++base_iterator;
- }
- return *this;
- }
-
- Base::Value& get_second() {
- if (current_null) {
- return *null_key_data;
- } else {
- return base_iterator->get_second();
- }
- }
- };
-
- Iterator begin() { return {Base::begin(), has_null_key, &null_key_data}; }
-
- Iterator end() { return {Base::end(), false, &null_key_data}; }
-
- void insert(const Iterator& other_iter) {
- if (other_iter.current_null) {
- has_null_key = true;
- null_key_data = *other_iter.null_key_data;
- } else {
- Base::insert(other_iter.base_iterator);
- }
- }
-
- using iterator = Iterator;
-};
-
/// Single low cardinality column.
template <typename SingleColumnMethod>
struct MethodSingleNullableColumn : public SingleColumnMethod {
diff --git a/be/src/exec/common/hash_table/ph_hash_map.h
b/be/src/exec/common/hash_table/ph_hash_map.h
index 92a6b5f9557..212b59609c9 100644
--- a/be/src/exec/common/hash_table/ph_hash_map.h
+++ b/be/src/exec/common/hash_table/ph_hash_map.h
@@ -188,6 +188,12 @@ public:
for (auto& v : *this) func(v.get_second());
}
+ /// Call func(const Key &, Mapped &) for each hash map element.
+ template <typename Func>
+ void for_each(Func&& func) {
+ for (auto& v : *this) func(v.get_first(), v.get_second());
+ }
+
size_t get_buffer_size_in_bytes() const {
const auto capacity = _hash_map.capacity();
return capacity * sizeof(typename HashMapImpl::slot_type);
diff --git a/be/src/exec/common/hash_table/string_hash_map.h
b/be/src/exec/common/hash_table/string_hash_map.h
index c615e6b5dcc..1c1dfce50dd 100644
--- a/be/src/exec/common/hash_table/string_hash_map.h
+++ b/be/src/exec/common/hash_table/string_hash_map.h
@@ -68,7 +68,7 @@ struct StringHashMapCell<doris::StringRef, TMapped>
using Base::Base;
static constexpr bool need_zero_value_storage = false;
// external
- using Base::get_key;
+ const doris::StringRef& get_key() const { return this->value.first; } ///
NOLINT
// internal
static const doris::StringRef& get_key(const value_type& value_) { return
value_.first; }
@@ -150,6 +150,29 @@ public:
func(v.get_second());
}
}
+
+ template <typename Func>
+ void for_each(Func&& func) {
+ if (this->m0.size()) {
+ func(this->m0.zero_value()->get_key(),
this->m0.zero_value()->get_second());
+ }
+ for (auto& v : this->m1) {
+ func(v.get_key(), v.get_second());
+ }
+ for (auto& v : this->m2) {
+ func(v.get_key(), v.get_second());
+ }
+ for (auto& v : this->m3) {
+ func(v.get_key(), v.get_second());
+ }
+ for (auto& v : this->m4) {
+ func(v.get_key(), v.get_second());
+ }
+ for (auto& v : this->ms) {
+ func(v.get_key(), v.get_second());
+ }
+ }
+
template <typename MappedType>
char* get_null_key_data() {
return nullptr;
diff --git a/be/src/exec/common/hash_table/string_hash_table.h
b/be/src/exec/common/hash_table/string_hash_table.h
index 007c2cf8fc9..5372f919739 100644
--- a/be/src/exec/common/hash_table/string_hash_table.h
+++ b/be/src/exec/common/hash_table/string_hash_table.h
@@ -50,7 +50,9 @@ StringKey to_string_key(const doris::StringRef& key) {
template <typename T>
inline doris::StringRef ALWAYS_INLINE to_string_ref(const T& n) {
assert(n != 0);
- return {reinterpret_cast<const char*>(&n), sizeof(T) - (__builtin_clzll(n)
>> 3)};
+ // __builtin_clzll counts leading zero bits in a 64-bit (8-byte) value,
+ // so we must use 8 here instead of sizeof(T) to get the correct byte
count.
+ return {reinterpret_cast<const char*>(&n), static_cast<size_t>(8 -
(__builtin_clzll(n) >> 3))};
}
inline doris::StringRef ALWAYS_INLINE to_string_ref(const StringKey16& n) {
assert(n.items[1] != 0);
@@ -415,7 +417,7 @@ protected:
return static_cast<Derived&>(*this);
}
- auto& operator*() const {
+ auto& operator*() {
switch (sub_table_index) {
case 0: {
this->cell = *(container->m0.zero_value());
@@ -444,9 +446,13 @@ protected:
}
return cell;
}
- auto* operator->() const { return &(this->operator*()); }
+ auto* operator->() { return &(this->operator*()); }
- auto get_ptr() const { return &(this->operator*()); }
+ auto get_ptr() { return &(this->operator*()); }
+
+ // Provide get_first()/get_second() at the iterator level, consistent
with PHHashMap::iterator
+ auto& get_first() { return (**this).get_first(); }
+ auto& get_second() { return (**this).get_second(); }
size_t get_hash() const {
switch (sub_table_index) {
diff --git a/be/src/exec/operator/aggregation_sink_operator.cpp
b/be/src/exec/operator/aggregation_sink_operator.cpp
index 133b91f4f1a..a8d95f58436 100644
--- a/be/src/exec/operator/aggregation_sink_operator.cpp
+++ b/be/src/exec/operator/aggregation_sink_operator.cpp
@@ -25,6 +25,7 @@
#include "core/data_type/primitive_type.h"
#include "exec/common/hash_table/hash.h"
#include "exec/operator/operator.h"
+#include "exprs/aggregate/aggregate_function_count.h"
#include "exprs/aggregate/aggregate_function_simple_factory.h"
#include "exprs/vectorized_agg_fn.h"
#include "runtime/runtime_profile.h"
@@ -156,6 +157,30 @@ Status AggSinkLocalState::open(RuntimeState* state) {
RETURN_IF_ERROR(_create_agg_status(_agg_data->without_key));
_shared_state->agg_data_created_without_key = true;
}
+
+ // Determine whether to use simple count aggregation.
+ // For queries like: SELECT xxx, count(*) / count(not_null_column) FROM
table GROUP BY xxx,
+ // count(*) / count(not_null_column) can store a uint64 counter directly
in the hash table,
+ // instead of storing the full aggregate state, saving memory and
computation overhead.
+ // Requirements:
+ // 0. The aggregation has a GROUP BY clause.
+ // 1. There is exactly one count aggregate function.
+ // 2. No limit optimization is applied.
+ // 3. Spill is not enabled (the spill path accesses
aggregate_data_container, which is empty in inline count mode).
+ // Supports update / merge / finalize / serialize phases, since count's
serialization format is UInt64 itself.
+
+ if (!Base::_shared_state->probe_expr_ctxs.empty() /* has GROUP BY */
+ && (p._aggregate_evaluators.size() == 1 &&
+ p._aggregate_evaluators[0]->function()->is_simple_count()) /* only
one count(*) */
+ && !_should_limit_output /* no limit optimization */ &&
+ !Base::_shared_state->enable_spill /* spill not enabled */) {
+ _shared_state->use_simple_count = true;
+#ifndef NDEBUG
+ // Randomly enable/disable in debug mode to verify correctness of
multi-phase agg promotion/demotion.
+ _shared_state->use_simple_count = rand() % 2 == 0;
+#endif
+ }
+
return Status::OK();
}
@@ -335,7 +360,18 @@ Status
AggSinkLocalState::_merge_with_serialized_key_helper(Block* block) {
key_columns,
(uint32_t)rows);
rows = block->rows();
} else {
- _emplace_into_hash_table(_places.data(), key_columns,
(uint32_t)rows);
+ if (_shared_state->use_simple_count) {
+ DCHECK(!for_spill);
+
+ auto col_id = AggSharedState::get_slot_column_id(
+ Base::_shared_state->aggregate_evaluators[0]);
+
+ auto column = block->get_by_position(col_id).column;
+ _merge_into_hash_table_inline_count(key_columns, column.get(),
(uint32_t)rows);
+ need_do_agg = false;
+ } else {
+ _emplace_into_hash_table(_places.data(), key_columns,
(uint32_t)rows);
+ }
}
if (need_do_agg) {
@@ -496,7 +532,9 @@ Status
AggSinkLocalState::_execute_with_serialized_key_helper(Block* block) {
}
} else {
_emplace_into_hash_table(_places.data(), key_columns, rows);
- RETURN_IF_ERROR(do_aggregate_evaluators());
+ if (!_shared_state->use_simple_count) {
+ RETURN_IF_ERROR(do_aggregate_evaluators());
+ }
if (_should_limit_output && !Base::_shared_state->enable_spill) {
const size_t hash_table_size = _get_hash_table_size();
@@ -524,6 +562,11 @@ size_t AggSinkLocalState::_get_hash_table_size() const {
void AggSinkLocalState::_emplace_into_hash_table(AggregateDataPtr* places,
ColumnRawPtrs& key_columns,
uint32_t num_rows) {
+ if (_shared_state->use_simple_count) {
+ _emplace_into_hash_table_inline_count(key_columns, num_rows);
+ return;
+ }
+
std::visit(Overload {[&](std::monostate& arg) -> void {
throw doris::Exception(ErrorCode::INTERNAL_ERROR,
"uninited hash table");
@@ -570,6 +613,84 @@ void
AggSinkLocalState::_emplace_into_hash_table(AggregateDataPtr* places,
_agg_data->method_variant);
}
+// For the agg hashmap<key, value>, the value is a char* type which is exactly
64 bits.
+// Here we treat it as a uint64 counter: each time the same key is
encountered, the counter
+// is incremented by 1. This avoids storing the full aggregate state, saving
memory and computation overhead.
+void AggSinkLocalState::_emplace_into_hash_table_inline_count(ColumnRawPtrs&
key_columns,
+ uint32_t
num_rows) {
+ std::visit(Overload {[&](std::monostate& arg) -> void {
+ throw doris::Exception(ErrorCode::INTERNAL_ERROR,
+ "uninited hash table");
+ },
+ [&](auto& agg_method) -> void {
+ SCOPED_TIMER(_hash_table_compute_timer);
+ using HashMethodType =
std::decay_t<decltype(agg_method)>;
+ using AggState = typename HashMethodType::State;
+ AggState state(key_columns);
+ agg_method.init_serialized_keys(key_columns,
num_rows);
+
+ auto creator = [&](const auto& ctor, auto& key,
auto& origin) {
+ HashMethodType::try_presis_key_and_origin(
+ key, origin,
Base::_shared_state->agg_arena_pool);
+ AggregateDataPtr mapped = nullptr;
+ ctor(key, mapped);
+ };
+
+ auto creator_for_null_key = [&](auto& mapped) {
mapped = nullptr; };
+
+ SCOPED_TIMER(_hash_table_emplace_timer);
+ for (size_t i = 0; i < num_rows; ++i) {
+ auto* mapped_ptr =
agg_method.lazy_emplace(state, i, creator,
+
creator_for_null_key);
+ ++reinterpret_cast<UInt64&>(*mapped_ptr);
+ }
+
+ COUNTER_UPDATE(_hash_table_input_counter,
num_rows);
+ }},
+ _agg_data->method_variant);
+}
+
+void AggSinkLocalState::_merge_into_hash_table_inline_count(ColumnRawPtrs&
key_columns,
+ const IColumn*
merge_column,
+ uint32_t num_rows)
{
+ std::visit(Overload {[&](std::monostate& arg) -> void {
+ throw doris::Exception(ErrorCode::INTERNAL_ERROR,
+ "uninited hash table");
+ },
+ [&](auto& agg_method) -> void {
+ SCOPED_TIMER(_hash_table_compute_timer);
+ using HashMethodType =
std::decay_t<decltype(agg_method)>;
+ using AggState = typename HashMethodType::State;
+ AggState state(key_columns);
+ agg_method.init_serialized_keys(key_columns,
num_rows);
+
+ const auto& col =
+ assert_cast<const
ColumnFixedLengthObject&>(*merge_column);
+ const auto* col_data =
+ reinterpret_cast<const
AggregateFunctionCountData*>(
+ col.get_data().data());
+
+ auto creator = [&](const auto& ctor, auto& key,
auto& origin) {
+ HashMethodType::try_presis_key_and_origin(
+ key, origin,
Base::_shared_state->agg_arena_pool);
+ AggregateDataPtr mapped = nullptr;
+ ctor(key, mapped);
+ };
+
+ auto creator_for_null_key = [&](auto& mapped) {
mapped = nullptr; };
+
+ SCOPED_TIMER(_hash_table_emplace_timer);
+ for (size_t i = 0; i < num_rows; ++i) {
+ auto* mapped_ptr =
agg_method.lazy_emplace(state, i, creator,
+
creator_for_null_key);
+ reinterpret_cast<UInt64&>(*mapped_ptr) +=
col_data[i].count;
+ }
+
+ COUNTER_UPDATE(_hash_table_input_counter,
num_rows);
+ }},
+ _agg_data->method_variant);
+}
+
bool AggSinkLocalState::_emplace_into_hash_table_limit(AggregateDataPtr*
places, Block* block,
const std::vector<int>&
key_locs,
ColumnRawPtrs&
key_columns,
diff --git a/be/src/exec/operator/aggregation_sink_operator.h
b/be/src/exec/operator/aggregation_sink_operator.h
index f90671cb465..76e33d8bd1e 100644
--- a/be/src/exec/operator/aggregation_sink_operator.h
+++ b/be/src/exec/operator/aggregation_sink_operator.h
@@ -86,6 +86,10 @@ protected:
uint32_t num_rows);
void _emplace_into_hash_table(AggregateDataPtr* places, ColumnRawPtrs&
key_columns,
uint32_t num_rows);
+
+ void _emplace_into_hash_table_inline_count(ColumnRawPtrs& key_columns,
uint32_t num_rows);
+ void _merge_into_hash_table_inline_count(ColumnRawPtrs& key_columns,
+ const IColumn* merge_column,
uint32_t num_rows);
bool _emplace_into_hash_table_limit(AggregateDataPtr* places, Block* block,
const std::vector<int>& key_locs,
ColumnRawPtrs& key_columns, uint32_t
num_rows);
diff --git a/be/src/exec/operator/aggregation_source_operator.cpp
b/be/src/exec/operator/aggregation_source_operator.cpp
index 9c6041e4dc4..06463b63e6e 100644
--- a/be/src/exec/operator/aggregation_source_operator.cpp
+++ b/be/src/exec/operator/aggregation_source_operator.cpp
@@ -21,6 +21,7 @@
#include <string>
#include "common/exception.h"
+#include "core/column/column_fixed_length_object.h"
#include "exec/operator/operator.h"
#include "exprs/vectorized_agg_fn.h"
#include "exprs/vexpr_fwd.h"
@@ -131,6 +132,76 @@ Status
AggLocalState::_get_results_with_serialized_key(RuntimeState* state, Bloc
const auto size = std::min(data.size(),
size_t(state->batch_size()));
using KeyType =
std::decay_t<decltype(agg_method)>::Key;
std::vector<KeyType> keys(size);
+
+ if (shared_state.use_simple_count) {
+
DCHECK_EQ(shared_state.aggregate_evaluators.size(), 1);
+
+ value_data_types[0] =
shared_state.aggregate_evaluators[0]
+ ->function()
+
->get_serialized_type();
+ if (mem_reuse) {
+ value_columns[0] =
+
std::move(*block->get_by_position(key_size).column)
+ .mutate();
+ } else {
+ value_columns[0] =
shared_state.aggregate_evaluators[0]
+ ->function()
+
->create_serialize_column();
+ }
+
+ std::vector<UInt64> inline_counts(size);
+ uint32_t num_rows = 0;
+ {
+ SCOPED_TIMER(_hash_table_iterate_timer);
+ auto& it = agg_method.begin;
+ while (it != agg_method.end && num_rows <
state->batch_size()) {
+ keys[num_rows] = it.get_first();
+ inline_counts[num_rows] =
+ reinterpret_cast<const
UInt64&>(it.get_second());
+ ++it;
+ ++num_rows;
+ }
+ }
+
+ {
+ SCOPED_TIMER(_insert_keys_to_column_timer);
+ agg_method.insert_keys_into_columns(keys,
key_columns, num_rows);
+ }
+
+ // Write inline counts to serialized column
+ // AggregateFunctionCountData = { UInt64 count },
same layout as inline
+ auto& count_col =
+
assert_cast<ColumnFixedLengthObject&>(*value_columns[0]);
+ count_col.resize(num_rows);
+ auto* col_data = count_col.get_data().data();
+ for (uint32_t i = 0; i < num_rows; ++i) {
+ *reinterpret_cast<UInt64*>(col_data + i *
sizeof(UInt64)) =
+ inline_counts[i];
+ }
+
+ // Handle null key if present
+ if (agg_method.begin == agg_method.end) {
+ if
(agg_method.hash_table->has_null_key_data()) {
+ DCHECK(key_columns.size() == 1);
+ DCHECK(key_columns[0]->is_nullable());
+ if (num_rows < state->batch_size()) {
+ key_columns[0]->insert_data(nullptr,
0);
+ auto mapped =
+
agg_method.hash_table->template get_null_key_data<
+ AggregateDataPtr>();
+ count_col.resize(num_rows + 1);
+
*reinterpret_cast<UInt64*>(count_col.get_data().data() +
+ num_rows *
sizeof(UInt64)) =
+ std::bit_cast<UInt64>(mapped);
+ *eos = true;
+ }
+ } else {
+ *eos = true;
+ }
+ }
+ return;
+ }
+
if (shared_state.values.size() < size + 1) {
shared_state.values.resize(size + 1);
}
@@ -255,6 +326,52 @@ Status
AggLocalState::_get_with_serialized_key_result(RuntimeState* state, Block
const auto size = std::min(data.size(),
size_t(state->batch_size()));
using KeyType =
std::decay_t<decltype(agg_method)>::Key;
std::vector<KeyType> keys(size);
+
+ if (shared_state.use_simple_count) {
+ // Inline count: mapped slot stores UInt64 count
directly
+ // (not a real AggregateDataPtr). Iterate hash
table directly.
+ DCHECK_EQ(value_columns.size(), 1);
+ auto& count_column =
assert_cast<ColumnInt64&>(*value_columns[0]);
+ uint32_t num_rows = 0;
+ {
+ SCOPED_TIMER(_hash_table_iterate_timer);
+ auto& it = agg_method.begin;
+ while (it != agg_method.end && num_rows <
state->batch_size()) {
+ keys[num_rows] = it.get_first();
+ auto& mapped = it.get_second();
+
count_column.insert_value(static_cast<Int64>(
+ reinterpret_cast<const
UInt64&>(mapped)));
+ ++it;
+ ++num_rows;
+ }
+ }
+ {
+ SCOPED_TIMER(_insert_keys_to_column_timer);
+ agg_method.insert_keys_into_columns(keys,
key_columns, num_rows);
+ }
+
+ // Handle null key if present
+ if (agg_method.begin == agg_method.end) {
+ if
(agg_method.hash_table->has_null_key_data()) {
+ DCHECK(key_columns.size() == 1);
+ DCHECK(key_columns[0]->is_nullable());
+ if (key_columns[0]->size() <
state->batch_size()) {
+ key_columns[0]->insert_data(nullptr,
0);
+ auto mapped =
+
agg_method.hash_table->template get_null_key_data<
+ AggregateDataPtr>();
+ count_column.insert_value(
+
static_cast<Int64>(std::bit_cast<UInt64>(mapped)));
+ *eos = true;
+ }
+ } else {
+ *eos = true;
+ }
+ }
+ return;
+ }
+
+ // Normal (non-simple-count) path
if (shared_state.values.size() < size) {
shared_state.values.resize(size);
}
diff --git a/be/src/exec/operator/set_probe_sink_operator.cpp
b/be/src/exec/operator/set_probe_sink_operator.cpp
index 26913e97641..32f61f73308 100644
--- a/be/src/exec/operator/set_probe_sink_operator.cpp
+++ b/be/src/exec/operator/set_probe_sink_operator.cpp
@@ -230,6 +230,35 @@ void
SetProbeSinkOperatorX<is_intersect>::_refresh_hash_table(
std::make_shared<typename
HashTableCtxType::HashMapType>();
tmp_hash_table->reserve(
local_state._shared_state->valid_element_in_hash_tbl);
+
+ // Handle null key separately since iterator does not
cover it
+ using NullMappedType =
+ std::decay_t<decltype(arg.hash_table->template
get_null_key_data<
+ RowRefWithFlag>())>;
+ if constexpr (std::is_same_v<NullMappedType,
RowRefWithFlag>) {
+ if (arg.hash_table->has_null_key_data()) {
+ auto& null_mapped =
+ arg.hash_table
+ ->template
get_null_key_data<RowRefWithFlag>();
+ if constexpr (is_intersect) {
+ if (null_mapped.visited) {
+ null_mapped.visited = false;
+ tmp_hash_table->has_null_key_data() =
true;
+ tmp_hash_table
+ ->template
get_null_key_data<RowRefWithFlag>() =
+ null_mapped;
+ }
+ } else {
+ if (!null_mapped.visited) {
+ tmp_hash_table->has_null_key_data() =
true;
+ tmp_hash_table
+ ->template
get_null_key_data<RowRefWithFlag>() =
+ null_mapped;
+ }
+ }
+ }
+ }
+
while (iter != arg.end) {
auto& mapped = iter.get_second();
auto* it = &mapped;
@@ -249,6 +278,16 @@ void
SetProbeSinkOperatorX<is_intersect>::_refresh_hash_table(
arg.hash_table = std::move(tmp_hash_table);
} else if (is_intersect) {
DCHECK_EQ(valid_element_in_hash_tbl,
arg.hash_table->size());
+ // Reset null key's visited flag separately
+ using NullMappedType =
+ std::decay_t<decltype(arg.hash_table->template
get_null_key_data<
+ RowRefWithFlag>())>;
+ if constexpr (std::is_same_v<NullMappedType,
RowRefWithFlag>) {
+ if (arg.hash_table->has_null_key_data()) {
+ arg.hash_table->template
get_null_key_data<RowRefWithFlag>()
+ .visited = false;
+ }
+ }
while (iter != arg.end) {
auto& mapped = iter.get_second();
auto* it = &mapped;
diff --git a/be/src/exec/operator/set_source_operator.cpp
b/be/src/exec/operator/set_source_operator.cpp
index cd1ac4bc45a..55defd9dacc 100644
--- a/be/src/exec/operator/set_source_operator.cpp
+++ b/be/src/exec/operator/set_source_operator.cpp
@@ -145,13 +145,25 @@ Status
SetSourceOperatorX<is_intersect>::_get_data_in_hashtable(
}
};
+ // Output null key first (if present and not yet output)
+ if (!local_state._null_key_output &&
hash_table_ctx.hash_table->has_null_key_data()) {
+ auto value = hash_table_ctx.hash_table->template
get_null_key_data<RowRefWithFlag>();
+ static_assert(std::is_same_v<RowRefWithFlag,
std::decay_t<decltype(value)>> ||
+ std::is_same_v<char*, std::decay_t<decltype(value)>>);
+ if constexpr (std::is_same_v<RowRefWithFlag,
std::decay_t<decltype(value)>>) {
+ add_result(value);
+ }
+ local_state._null_key_output = true;
+ }
+
auto& iter = hash_table_ctx.begin;
- while (iter != hash_table_ctx.end && local_state._result_indexs.size() <
batch_size) {
+ while (iter != hash_table_ctx.hash_table->end() &&
+ local_state._result_indexs.size() < batch_size) {
add_result(iter.get_second());
++iter;
}
- *eos = iter == hash_table_ctx.end;
+ *eos = iter == hash_table_ctx.hash_table->end();
COUNTER_UPDATE(local_state._get_data_from_hashtable_rows,
local_state._result_indexs.size());
local_state._add_result_columns();
diff --git a/be/src/exec/operator/set_source_operator.h
b/be/src/exec/operator/set_source_operator.h
index 31e5fc77542..f2f245f1edc 100644
--- a/be/src/exec/operator/set_source_operator.h
+++ b/be/src/exec/operator/set_source_operator.h
@@ -51,6 +51,7 @@ private:
RuntimeProfile::Counter* _filter_timer = nullptr;
RuntimeProfile::Counter* _get_data_from_hashtable_rows = nullptr;
IColumn::Selector _result_indexs;
+ bool _null_key_output = false;
};
template <bool is_intersect>
diff --git a/be/src/exec/operator/streaming_aggregation_operator.cpp
b/be/src/exec/operator/streaming_aggregation_operator.cpp
index 4458415b1a5..8ce01aef955 100644
--- a/be/src/exec/operator/streaming_aggregation_operator.cpp
+++ b/be/src/exec/operator/streaming_aggregation_operator.cpp
@@ -24,7 +24,9 @@
#include "common/cast_set.h"
#include "common/compiler_util.h" // IWYU pragma: keep
+#include "core/column/column_fixed_length_object.h"
#include "exec/operator/operator.h"
+#include "exprs/aggregate/aggregate_function_count.h"
#include "exprs/aggregate/aggregate_function_simple_factory.h"
#include "exprs/vectorized_agg_fn.h"
#include "exprs/vslot_ref.h"
@@ -149,22 +151,36 @@ Status StreamingAggLocalState::open(RuntimeState* state) {
RETURN_IF_ERROR(_init_hash_method(_probe_expr_ctxs));
- std::visit(Overload {[&](std::monostate& arg) -> void {
- throw doris::Exception(ErrorCode::INTERNAL_ERROR,
- "uninited hash table");
- },
- [&](auto& agg_method) {
- using HashTableType =
std::decay_t<decltype(agg_method)>;
- using KeyType = typename HashTableType::Key;
-
- /// some aggregate functions (like AVG for
decimal) have align issues.
- _aggregate_data_container =
std::make_unique<AggregateDataContainer>(
- sizeof(KeyType),
((p._total_size_of_aggregate_states +
-
p._align_aggregate_states - 1) /
-
p._align_aggregate_states) *
-
p._align_aggregate_states);
- }},
- _agg_data->method_variant);
+ // Determine whether to use simple count aggregation.
+ // StreamingAgg only operates in update + serialize mode: input is raw
data, output is serialized intermediate state.
+ // The serialization format of count is UInt64 itself, so it can be
inlined into the hash table mapped slot.
+ if (_aggregate_evaluators.size() == 1 &&
+ _aggregate_evaluators[0]->function()->is_simple_count()) {
+ _use_simple_count = true;
+#ifndef NDEBUG
+ // Randomly enable/disable in debug mode to verify correctness of
multi-phase agg promotion/demotion.
+ _use_simple_count = rand() % 2 == 0;
+#endif
+ }
+
+ std::visit(
+ Overload {[&](std::monostate& arg) -> void {
+ throw doris::Exception(ErrorCode::INTERNAL_ERROR,
"uninited hash table");
+ },
+ [&](auto& agg_method) {
+ using HashTableType =
std::decay_t<decltype(agg_method)>;
+ using KeyType = typename HashTableType::Key;
+
+ if (!_use_simple_count) {
+ /// some aggregate functions (like AVG for
decimal) have align issues.
+ _aggregate_data_container =
std::make_unique<AggregateDataContainer>(
+ sizeof(KeyType),
((p._total_size_of_aggregate_states +
+
p._align_aggregate_states - 1) /
+
p._align_aggregate_states) *
+
p._align_aggregate_states);
+ }
+ }},
+ _agg_data->method_variant);
limit = p._sort_limit;
do_sort_limit = p._do_sort_limit;
@@ -191,8 +207,11 @@ void
StreamingAggLocalState::_update_memusage_with_serialized_key() {
},
[&](auto& agg_method) -> void {
auto& data = *agg_method.hash_table;
- int64_t arena_memory_usage =
_agg_arena_pool.size() +
-
_aggregate_data_container->memory_usage();
+ int64_t arena_memory_usage =
+ _agg_arena_pool.size() +
+ (_aggregate_data_container
+ ?
_aggregate_data_container->memory_usage()
+ : 0);
int64_t hash_table_memory_usage =
data.get_buffer_size_in_bytes();
COUNTER_SET(_memory_used_counter,
@@ -325,22 +344,23 @@ bool
StreamingAggLocalState::_should_not_do_pre_agg(size_t rows) {
const auto spill_streaming_agg_mem_limit =
p._spill_streaming_agg_mem_limit;
const bool used_too_much_memory =
spill_streaming_agg_mem_limit > 0 && _memory_usage() >
spill_streaming_agg_mem_limit;
- std::visit(Overload {[&](std::monostate& arg) {
- throw doris::Exception(ErrorCode::INTERNAL_ERROR,
- "uninited hash table");
- },
- [&](auto& agg_method) {
- auto& hash_tbl = *agg_method.hash_table;
- /// If too much memory is used during the
pre-aggregation stage,
- /// it is better to output the data directly
without performing further aggregation.
- // do not try to do agg, just init and serialize
directly return the out_block
- if (used_too_much_memory ||
(hash_tbl.add_elem_size_overflow(rows) &&
-
!_should_expand_preagg_hash_tables())) {
- SCOPED_TIMER(_streaming_agg_timer);
- ret_flag = true;
- }
- }},
- _agg_data->method_variant);
+ std::visit(
+ Overload {
+ [&](std::monostate& arg) {
+ throw doris::Exception(ErrorCode::INTERNAL_ERROR,
"uninited hash table");
+ },
+ [&](auto& agg_method) {
+ auto& hash_tbl = *agg_method.hash_table;
+ /// If too much memory is used during the
pre-aggregation stage,
+ /// it is better to output the data directly without
performing further aggregation.
+ // do not try to do agg, just init and serialize
directly return the out_block
+ if (used_too_much_memory ||
(hash_tbl.add_elem_size_overflow(rows) &&
+
!_should_expand_preagg_hash_tables())) {
+ SCOPED_TIMER(_streaming_agg_timer);
+ ret_flag = true;
+ }
+ }},
+ _agg_data->method_variant);
return ret_flag;
}
@@ -438,7 +458,12 @@ Status
StreamingAggLocalState::_pre_agg_with_serialized_key(doris::Block* in_blo
} else {
bool need_agg = true;
if (need_do_sort_limit != 1) {
- _emplace_into_hash_table(_places.data(), key_columns, rows);
+ if (_use_simple_count) {
+ _emplace_into_hash_table_inline_count(key_columns, rows);
+ need_agg = false;
+ } else {
+ _emplace_into_hash_table(_places.data(), key_columns, rows);
+ }
} else {
need_agg = _emplace_into_hash_table_limit(_places.data(),
in_block, key_columns, rows);
}
@@ -506,6 +531,74 @@ Status
StreamingAggLocalState::_get_results_with_serialized_key(RuntimeState* st
const auto size = std::min(data.size(),
size_t(state->batch_size()));
using KeyType =
std::decay_t<decltype(agg_method)>::Key;
std::vector<KeyType> keys(size);
+
+ if (_use_simple_count) {
+ DCHECK_EQ(_aggregate_evaluators.size(), 1);
+
+ value_data_types[0] =
+
_aggregate_evaluators[0]->function()->get_serialized_type();
+ if (mem_reuse) {
+ value_columns[0] =
+
std::move(*block->get_by_position(key_size).column)
+ .mutate();
+ } else {
+ value_columns[0] = _aggregate_evaluators[0]
+ ->function()
+
->create_serialize_column();
+ }
+
+ std::vector<UInt64> inline_counts(size);
+ uint32_t num_rows = 0;
+ {
+ SCOPED_TIMER(_hash_table_iterate_timer);
+ auto& it = agg_method.begin;
+ while (it != agg_method.end && num_rows <
state->batch_size()) {
+ keys[num_rows] = it.get_first();
+ inline_counts[num_rows] =
+ reinterpret_cast<const
UInt64&>(it.get_second());
+ ++it;
+ ++num_rows;
+ }
+ }
+
+ {
+ SCOPED_TIMER(_insert_keys_to_column_timer);
+ agg_method.insert_keys_into_columns(keys,
key_columns, num_rows);
+ }
+
+ // Write inline counts to serialized column
+ auto& count_col =
+
assert_cast<ColumnFixedLengthObject&>(*value_columns[0]);
+ count_col.resize(num_rows);
+ auto* col_data = count_col.get_data().data();
+ for (uint32_t i = 0; i < num_rows; ++i) {
+ *reinterpret_cast<UInt64*>(col_data + i *
sizeof(UInt64)) =
+ inline_counts[i];
+ }
+
+ // Handle null key if present
+ if (agg_method.begin == agg_method.end) {
+ if
(agg_method.hash_table->has_null_key_data()) {
+ DCHECK(key_columns.size() == 1);
+ DCHECK(key_columns[0]->is_nullable());
+ if (num_rows < state->batch_size()) {
+ key_columns[0]->insert_data(nullptr,
0);
+ auto mapped =
+
agg_method.hash_table->template get_null_key_data<
+ AggregateDataPtr>();
+ count_col.resize(num_rows + 1);
+
*reinterpret_cast<UInt64*>(count_col.get_data().data() +
+ num_rows *
sizeof(UInt64)) =
+ std::bit_cast<UInt64>(mapped);
+ *eos = true;
+ }
+ } else {
+ *eos = true;
+ }
+ }
+ return;
+ }
+
if (_values.size() < size + 1) {
_values.resize(size + 1);
}
@@ -778,6 +871,11 @@ bool StreamingAggLocalState::_do_limit_filter(size_t
num_rows, ColumnRawPtrs& ke
void StreamingAggLocalState::_emplace_into_hash_table(AggregateDataPtr* places,
ColumnRawPtrs&
key_columns,
const uint32_t num_rows)
{
+ if (_use_simple_count) {
+ _emplace_into_hash_table_inline_count(key_columns, num_rows);
+ return;
+ }
+
std::visit(Overload {[&](std::monostate& arg) -> void {
throw doris::Exception(ErrorCode::INTERNAL_ERROR,
"uninited hash table");
@@ -822,6 +920,40 @@ void
StreamingAggLocalState::_emplace_into_hash_table(AggregateDataPtr* places,
_agg_data->method_variant);
}
+void
StreamingAggLocalState::_emplace_into_hash_table_inline_count(ColumnRawPtrs&
key_columns,
+ uint32_t
num_rows) {
+ std::visit(Overload {[&](std::monostate& arg) -> void {
+ throw doris::Exception(ErrorCode::INTERNAL_ERROR,
+ "uninited hash table");
+ },
+ [&](auto& agg_method) -> void {
+ SCOPED_TIMER(_hash_table_compute_timer);
+ using HashMethodType =
std::decay_t<decltype(agg_method)>;
+ using AggState = typename HashMethodType::State;
+ AggState state(key_columns);
+ agg_method.init_serialized_keys(key_columns,
num_rows);
+
+ auto creator = [&](const auto& ctor, auto& key,
auto& origin) {
+
HashMethodType::try_presis_key_and_origin(key, origin,
+
_agg_arena_pool);
+ AggregateDataPtr mapped = nullptr;
+ ctor(key, mapped);
+ };
+
+ auto creator_for_null_key = [&](auto& mapped) {
mapped = nullptr; };
+
+ SCOPED_TIMER(_hash_table_emplace_timer);
+ for (size_t i = 0; i < num_rows; ++i) {
+ auto* mapped_ptr =
agg_method.lazy_emplace(state, i, creator,
+
creator_for_null_key);
+ ++reinterpret_cast<UInt64&>(*mapped_ptr);
+ }
+
+ COUNTER_UPDATE(_hash_table_input_counter,
num_rows);
+ }},
+ _agg_data->method_variant);
+}
+
StreamingAggOperatorX::StreamingAggOperatorX(ObjectPool* pool, int operator_id,
const TPlanNode& tnode, const
DescriptorTbl& descs)
: StatefulOperatorX<StreamingAggLocalState>(pool, tnode, operator_id,
descs),
diff --git a/be/src/exec/operator/streaming_aggregation_operator.h
b/be/src/exec/operator/streaming_aggregation_operator.h
index abc905cbfbd..2194a8ad725 100644
--- a/be/src/exec/operator/streaming_aggregation_operator.h
+++ b/be/src/exec/operator/streaming_aggregation_operator.h
@@ -68,6 +68,7 @@ private:
Status _get_results_with_serialized_key(RuntimeState* state, Block* block,
bool* eos);
void _emplace_into_hash_table(AggregateDataPtr* places, ColumnRawPtrs&
key_columns,
const uint32_t num_rows);
+ void _emplace_into_hash_table_inline_count(ColumnRawPtrs& key_columns,
uint32_t num_rows);
bool _emplace_into_hash_table_limit(AggregateDataPtr* places, Block* block,
ColumnRawPtrs& key_columns, uint32_t
num_rows);
Status _create_agg_status(AggregateDataPtr data);
@@ -98,6 +99,7 @@ private:
// group by k1,k2
VExprContextSPtrs _probe_expr_ctxs;
std::unique_ptr<AggregateDataContainer> _aggregate_data_container =
nullptr;
+ bool _use_simple_count = false;
bool _reach_limit = false;
size_t _input_num_rows = 0;
@@ -179,6 +181,11 @@ private:
// Do nothing
},
[&](auto& agg_method) -> void {
+ if (_use_simple_count) {
+ // Inline count: mapped slots hold UInt64,
+ // not real agg state pointers. Skip
destroy.
+ return;
+ }
auto& data = *agg_method.hash_table;
data.for_each_mapped([&](auto& mapped) {
if (mapped) {
diff --git a/be/src/exec/pipeline/dependency.h
b/be/src/exec/pipeline/dependency.h
index 018e699dbe6..3bebc9b55f7 100644
--- a/be/src/exec/pipeline/dependency.h
+++ b/be/src/exec/pipeline/dependency.h
@@ -320,6 +320,7 @@ public:
bool enable_spill = false;
bool reach_limit = false;
+ bool use_simple_count = false;
int64_t limit = -1;
bool do_sort_limit = false;
MutableColumns limit_columns;
@@ -392,6 +393,11 @@ private:
// Do nothing
},
[&](auto& agg_method) -> void {
+ if (use_simple_count) {
+ // Inline count: mapped slots hold UInt64,
+ // not real agg state pointers. Skip
destroy.
+ return;
+ }
auto& data = *agg_method.hash_table;
data.for_each_mapped([&](auto& mapped) {
if (mapped) {
diff --git a/be/src/exprs/aggregate/aggregate_function.h
b/be/src/exprs/aggregate/aggregate_function.h
index 475439cd39c..d7c97a3f944 100644
--- a/be/src/exprs/aggregate/aggregate_function.h
+++ b/be/src/exprs/aggregate/aggregate_function.h
@@ -263,6 +263,8 @@ public:
virtual bool is_blockable() const { return false; }
+ virtual bool is_simple_count() const { return false; }
+
/**
* Executes the aggregate function in incremental mode.
* This is a virtual function that should be overridden by aggregate
functions supporting incremental calculation.
diff --git a/be/src/exprs/aggregate/aggregate_function_count.h
b/be/src/exprs/aggregate/aggregate_function_count.h
index 35317a6240a..3bc825a4a5a 100644
--- a/be/src/exprs/aggregate/aggregate_function_count.h
+++ b/be/src/exprs/aggregate/aggregate_function_count.h
@@ -57,6 +57,7 @@ public:
AggregateFunctionCount(const DataTypes& argument_types_)
: IAggregateFunctionDataHelper(argument_types_) {}
+ bool is_simple_count() const override { return true; }
String get_name() const override { return "count"; }
DataTypePtr get_return_type() const override { return
std::make_shared<DataTypeInt64>(); }
diff --git a/be/test/exec/hash_map/hash_table_method_test.cpp
b/be/test/exec/hash_map/hash_table_method_test.cpp
index a7d93d7780f..697ce283751 100644
--- a/be/test/exec/hash_map/hash_table_method_test.cpp
+++ b/be/test/exec/hash_map/hash_table_method_test.cpp
@@ -17,7 +17,10 @@
#include <gtest/gtest.h>
+#include <set>
+
#include "core/data_type/data_type_number.h"
+#include "exec/common/agg_utils.h"
#include "exec/common/columns_hashing.h"
#include "exec/common/hash_table/hash.h"
#include "exec/common/hash_table/hash_map_context.h"
@@ -127,4 +130,711 @@ TEST(HashTableMethodTest, testMethodStringNoCache) {
{0, 1, -1, 3, -1, 4});
}
-} // namespace doris
\ No newline at end of file
+// Verify that iterating a DataWithNullKey hash map via
init_iterator()/begin/end
+// does NOT visit the null key entry. The null key must be accessed separately
+// through has_null_key_data()/get_null_key_data().
+TEST(HashTableMethodTest, testNullableIteratorSkipsNullKey) {
+ using NullableMethod = MethodSingleNullableColumn<MethodOneNumber<
+ UInt32, DataWithNullKey<PHHashMap<UInt32, IColumn::ColumnIndex,
HashCRC32<UInt32>>>>>;
+ NullableMethod method;
+
+ // data: {1, 0(null), 2, 0(null), 3}
+ // null_map: {0, 1, 0, 1, 0} — positions 1 and 3 are null
+ auto nullable_col =
+ ColumnHelper::create_nullable_column<DataTypeInt32>({1, 0, 2, 0,
3}, {0, 1, 0, 1, 0});
+
+ // Insert all rows including nulls
+ {
+ using State = typename NullableMethod::State;
+ ColumnRawPtrs key_raw_columns {nullable_col.get()};
+ State state(key_raw_columns);
+ const size_t rows = nullable_col->size();
+ method.init_serialized_keys(key_raw_columns, rows);
+
+ for (size_t i = 0; i < rows; i++) {
+ IColumn::ColumnIndex mapped_value = i;
+ auto creator = [&](const auto& ctor, auto& key, auto& origin) {
+ ctor(key, mapped_value);
+ };
+ auto creator_for_null_key = [&](auto& mapped) { mapped =
mapped_value; };
+ method.lazy_emplace(state, i, creator, creator_for_null_key);
+ }
+ }
+
+ // hash_table->size() includes null key: 3 non-null + 1 null = 4
+ EXPECT_EQ(method.hash_table->size(), 4);
+
+    // Besides the 3 non-null entries, the null key must be recorded separately
+ EXPECT_TRUE(method.hash_table->has_null_key_data());
+
+ // Iterate via init_iterator — should only visit 3 non-null entries
+ method.init_iterator();
+ size_t iter_count = 0;
+ std::set<IColumn::ColumnIndex> visited_values;
+ auto iter = method.begin;
+ while (iter != method.end) {
+ visited_values.insert(iter.get_second());
+ ++iter;
+ ++iter_count;
+ }
+
+ // Iterator must visit exactly 3 entries (the non-null keys: 1, 2, 3)
+ EXPECT_EQ(iter_count, 3);
+ // Mapped values for non-null rows are 0 (key=1), 2 (key=2), 4 (key=3)
+ EXPECT_TRUE(visited_values.count(0)); // row 0: key=1
+ EXPECT_TRUE(visited_values.count(2)); // row 2: key=2
+ EXPECT_TRUE(visited_values.count(4)); // row 4: key=3
+ // The null key's mapped value (1, from the first null row) must NOT
appear in iteration
+ EXPECT_FALSE(visited_values.count(1));
+
+ // Null key must be accessible separately
+ auto null_mapped = method.hash_table->template
get_null_key_data<IColumn::ColumnIndex>();
+ EXPECT_EQ(null_mapped, 1); // first null insertion at row 1
+
+    // find should locate both null and non-null keys, and miss absent ones
+ {
+ using State = typename NullableMethod::State;
+ // Search: {1, null, 99, 2}
+ auto search_col =
+ ColumnHelper::create_nullable_column<DataTypeInt32>({1, 0, 99,
2}, {0, 1, 0, 0});
+ ColumnRawPtrs key_raw_columns {search_col.get()};
+ State state(key_raw_columns);
+ method.init_serialized_keys(key_raw_columns, 4);
+
+ // key=1 found
+ auto r0 = method.find(state, 0);
+ EXPECT_TRUE(r0.is_found());
+ EXPECT_EQ(r0.get_mapped(), 0);
+
+ // null found
+ auto r1 = method.find(state, 1);
+ EXPECT_TRUE(r1.is_found());
+ EXPECT_EQ(r1.get_mapped(), 1);
+
+ // key=99 not found
+ auto r2 = method.find(state, 2);
+ EXPECT_FALSE(r2.is_found());
+
+ // key=2 found
+ auto r3 = method.find(state, 3);
+ EXPECT_TRUE(r3.is_found());
+ EXPECT_EQ(r3.get_mapped(), 2);
+ }
+}
+
+// Helper: create distinguishable AggregateDataPtr values for testing
+static AggregateDataPtr make_mapped(size_t val) {
+ return reinterpret_cast<AggregateDataPtr>(val);
+}
+
+// ========== MethodOneNumber<UInt32, AggData<UInt32>> ==========
+// AggData<UInt32> = PHHashMap<UInt32, AggregateDataPtr, HashCRC32<UInt32>>
+TEST(HashTableMethodTest, testMethodOneNumberAggInsertFindForEach) {
+ MethodOneNumber<UInt32, AggData<UInt32>> method;
+ using State = MethodOneNumber<UInt32, AggData<UInt32>>::State;
+
+ auto col = ColumnHelper::create_column<DataTypeInt32>({10, 20, 30, 40,
50});
+ ColumnRawPtrs key_columns = {col.get()};
+ const size_t rows = 5;
+
+ // Insert
+ {
+ State state(key_columns);
+ method.init_serialized_keys(key_columns, rows);
+ for (size_t i = 0; i < rows; i++) {
+ method.lazy_emplace(
+ state, i,
+ [&](const auto& ctor, auto& key, auto& origin) {
+ ctor(key, make_mapped(i + 1));
+ },
+ [](auto& mapped) { FAIL() << "unexpected null"; });
+ }
+ }
+
+ // Find existing keys
+ {
+ State state(key_columns);
+ method.init_serialized_keys(key_columns, rows);
+ for (size_t i = 0; i < rows; i++) {
+ auto result = method.find(state, i);
+ ASSERT_TRUE(result.is_found());
+ EXPECT_EQ(result.get_mapped(), make_mapped(i + 1));
+ }
+ }
+
+ // Find non-existing key
+ {
+ auto miss_col = ColumnHelper::create_column<DataTypeInt32>({999});
+ ColumnRawPtrs miss_columns = {miss_col.get()};
+ State state(miss_columns);
+ method.init_serialized_keys(miss_columns, 1);
+ auto result = method.find(state, 0);
+ EXPECT_FALSE(result.is_found());
+ }
+
+ // for_each
+ {
+ size_t count = 0;
+ method.hash_table->for_each([&](const auto& key, auto& mapped) {
+ EXPECT_NE(mapped, nullptr);
+ count++;
+ });
+ EXPECT_EQ(count, 5);
+ }
+
+ // for_each_mapped
+ {
+ size_t count = 0;
+ method.hash_table->for_each_mapped([&](auto& mapped) {
+ EXPECT_NE(mapped, nullptr);
+ count++;
+ });
+ EXPECT_EQ(count, 5);
+ }
+}
+
+// ========== MethodOneNumber Phase2 (HashMixWrapper) ==========
+// AggregatedDataWithUInt32KeyPhase2 = PHHashMap<UInt32, AggregateDataPtr,
HashMixWrapper<UInt32>>
+TEST(HashTableMethodTest, testMethodOneNumberPhase2AggInsertFindForEach) {
+ MethodOneNumber<UInt32, AggregatedDataWithUInt32KeyPhase2> method;
+ using State = MethodOneNumber<UInt32,
AggregatedDataWithUInt32KeyPhase2>::State;
+
+ auto col = ColumnHelper::create_column<DataTypeInt32>({100, 200, 300});
+ ColumnRawPtrs key_columns = {col.get()};
+ const size_t rows = 3;
+
+ // Insert
+ {
+ State state(key_columns);
+ method.init_serialized_keys(key_columns, rows);
+ for (size_t i = 0; i < rows; i++) {
+ method.lazy_emplace(
+ state, i,
+ [&](const auto& ctor, auto& key, auto& origin) {
+ ctor(key, make_mapped(i + 100));
+ },
+ [](auto& mapped) { FAIL(); });
+ }
+ }
+
+ // Find
+ {
+ State state(key_columns);
+ method.init_serialized_keys(key_columns, rows);
+ for (size_t i = 0; i < rows; i++) {
+ auto result = method.find(state, i);
+ ASSERT_TRUE(result.is_found());
+ EXPECT_EQ(result.get_mapped(), make_mapped(i + 100));
+ }
+ }
+
+ // for_each + for_each_mapped
+ {
+ size_t count = 0;
+ method.hash_table->for_each([&](const auto& key, auto& mapped) {
count++; });
+ EXPECT_EQ(count, 3);
+ }
+ {
+ size_t count = 0;
+ method.hash_table->for_each_mapped([&](auto& mapped) { count++; });
+ EXPECT_EQ(count, 3);
+ }
+}
+
+// ========== MethodStringNoCache<AggregatedDataWithShortStringKey> ==========
+// AggregatedDataWithShortStringKey = StringHashMap<AggregateDataPtr>
+TEST(HashTableMethodTest, testMethodStringNoCacheAggInsertFindForEach) {
+ MethodStringNoCache<AggregatedDataWithShortStringKey> method;
+ using State = MethodStringNoCache<AggregatedDataWithShortStringKey>::State;
+
+ // Include strings of varying lengths to exercise different StringHashMap
sub-maps
+ auto col = ColumnHelper::create_column<DataTypeString>(
+ {"hello", "world", "foo", "bar", "longstring_exceeding_16_bytes"});
+ ColumnRawPtrs key_columns = {col.get()};
+ const size_t rows = 5;
+
+ // Insert
+ {
+ State state(key_columns);
+ method.init_serialized_keys(key_columns, rows);
+ for (size_t i = 0; i < rows; i++) {
+ method.lazy_emplace(
+ state, i,
+ [&](const auto& ctor, auto& key, auto& origin) {
+ ctor(key, make_mapped(i + 10));
+ },
+ [](auto& mapped) { FAIL(); });
+ }
+ }
+
+ // Find
+ {
+ State state(key_columns);
+ method.init_serialized_keys(key_columns, rows);
+ for (size_t i = 0; i < rows; i++) {
+ auto result = method.find(state, i);
+ ASSERT_TRUE(result.is_found());
+ EXPECT_EQ(result.get_mapped(), make_mapped(i + 10));
+ }
+ }
+
+ // for_each
+ {
+ size_t count = 0;
+ method.hash_table->for_each([&](const auto& key, auto& mapped) {
+ EXPECT_NE(mapped, nullptr);
+ count++;
+ });
+ EXPECT_EQ(count, 5);
+ }
+
+ // for_each_mapped
+ {
+ size_t count = 0;
+ method.hash_table->for_each_mapped([&](auto& mapped) { count++; });
+ EXPECT_EQ(count, 5);
+ }
+}
+
+// ========== MethodSerialized<AggregatedDataWithStringKey> ==========
+// AggregatedDataWithStringKey = PHHashMap<StringRef, AggregateDataPtr>
+// StringRef keys require arena persistence to survive across
init_serialized_keys calls.
+TEST(HashTableMethodTest, testMethodSerializedAggInsertFindForEach) {
+ MethodSerialized<AggregatedDataWithStringKey> method;
+ using State = MethodSerialized<AggregatedDataWithStringKey>::State;
+
+ auto col1 = ColumnHelper::create_column<DataTypeInt32>({1, 2, 3});
+ auto col2 = ColumnHelper::create_column<DataTypeString>({"a", "bb",
"ccc"});
+ ColumnRawPtrs key_columns = {col1.get(), col2.get()};
+ const size_t rows = 3;
+
+ // Use a separate arena to persist StringRef key data
+ Arena persist_arena;
+
+ // Insert
+ {
+ State state(key_columns);
+ method.init_serialized_keys(key_columns, rows);
+ for (size_t i = 0; i < rows; i++) {
+ method.lazy_emplace(
+ state, i,
+ [&](const auto& ctor, auto& key, auto& origin) {
+ method.try_presis_key_and_origin(key, origin,
persist_arena);
+ ctor(key, make_mapped(i + 50));
+ },
+ [](auto& mapped) { FAIL(); });
+ }
+ }
+
+ // for_each (keys backed by persist_arena)
+ {
+ size_t count = 0;
+ method.hash_table->for_each([&](const auto& key, auto& mapped) {
+ EXPECT_GT(key.size, 0);
+ EXPECT_NE(mapped, nullptr);
+ count++;
+ });
+ EXPECT_EQ(count, 3);
+ }
+
+ // for_each_mapped
+ {
+ size_t count = 0;
+ method.hash_table->for_each_mapped([&](auto& mapped) {
+ EXPECT_NE(mapped, nullptr);
+ count++;
+ });
+ EXPECT_EQ(count, 3);
+ }
+
+ // Find (re-init serialized keys for lookup)
+ {
+ State state(key_columns);
+ method.init_serialized_keys(key_columns, rows);
+ for (size_t i = 0; i < rows; i++) {
+ auto result = method.find(state, i);
+ ASSERT_TRUE(result.is_found());
+ EXPECT_EQ(result.get_mapped(), make_mapped(i + 50));
+ }
+ }
+}
+
+// ========== MethodKeysFixed<AggData<UInt64>> ==========
+// Fixed-width multi-column keys packed into UInt64
+TEST(HashTableMethodTest, testMethodKeysFixedAggInsertFindForEach) {
+ MethodKeysFixed<AggData<UInt64>> method(Sizes {sizeof(int32_t),
sizeof(int32_t)});
+ using State = MethodKeysFixed<AggData<UInt64>>::State;
+
+ auto col1 = ColumnHelper::create_column<DataTypeInt32>({1, 2, 3, 4});
+ auto col2 = ColumnHelper::create_column<DataTypeInt32>({10, 20, 30, 40});
+ ColumnRawPtrs key_columns = {col1.get(), col2.get()};
+ const size_t rows = 4;
+
+ // Insert
+ {
+ State state(key_columns);
+ method.init_serialized_keys(key_columns, rows);
+ for (size_t i = 0; i < rows; i++) {
+ method.lazy_emplace(
+ state, i,
+ [&](const auto& ctor, auto& key, auto& origin) {
+ ctor(key, make_mapped(i + 200));
+ },
+ [](auto& mapped) { FAIL(); });
+ }
+ }
+
+ // Find
+ {
+ State state(key_columns);
+ method.init_serialized_keys(key_columns, rows);
+ for (size_t i = 0; i < rows; i++) {
+ auto result = method.find(state, i);
+ ASSERT_TRUE(result.is_found());
+ EXPECT_EQ(result.get_mapped(), make_mapped(i + 200));
+ }
+ }
+
+ // for_each
+ {
+ size_t count = 0;
+ method.hash_table->for_each([&](const auto& key, auto& mapped) {
+ EXPECT_NE(mapped, nullptr);
+ count++;
+ });
+ EXPECT_EQ(count, 4);
+ }
+}
+
+// ========== Nullable MethodOneNumber (MethodSingleNullableColumn +
MethodOneNumber) ==========
+// AggDataNullable<UInt32> = DataWithNullKey<PHHashMap<UInt32,
AggregateDataPtr, HashCRC32<UInt32>>>
+// Tests null key insertion, find, and for_each_mapped (which skips the null key).
+TEST(HashTableMethodTest, testNullableMethodOneNumberAggInsertFindForEach) {
+ using NullableMethod =
+ MethodSingleNullableColumn<MethodOneNumber<UInt32,
AggDataNullable<UInt32>>>;
+ NullableMethod method;
+ using State = NullableMethod::State;
+
+ // values: {10, 20, 30, 40, 50}, null at rows 1 and 3
+ auto col = ColumnHelper::create_nullable_column<DataTypeInt32>({10, 20,
30, 40, 50},
+ {0, 1, 0,
1, 0});
+ ColumnRawPtrs key_columns = {col.get()};
+ const size_t rows = 5;
+
+ // Insert
+ size_t null_create_count = 0;
+ {
+ State state(key_columns);
+ method.init_serialized_keys(key_columns, rows);
+ for (size_t i = 0; i < rows; i++) {
+ method.lazy_emplace(
+ state, i,
+ [&](const auto& ctor, auto& key, auto& origin) {
+ ctor(key, make_mapped(i + 1));
+ },
+ [&](auto& mapped) {
+ null_create_count++;
+ mapped = make_mapped(999);
+ });
+ }
+ }
+
+ // null_creator called once for first null row (index 1); second null
(index 3) is deduplicated
+ EXPECT_EQ(null_create_count, 1);
+
+ auto& ht = *method.hash_table;
+ EXPECT_TRUE(ht.has_null_key_data());
+ EXPECT_EQ(ht.get_null_key_data<AggregateDataPtr>(), make_mapped(999));
+
+ // Find
+ {
+ State state(key_columns);
+ method.init_serialized_keys(key_columns, rows);
+
+ // row 0: key=10, non-null
+ auto r0 = method.find(state, 0);
+ ASSERT_TRUE(r0.is_found());
+ EXPECT_EQ(r0.get_mapped(), make_mapped(1));
+
+ // row 1: null → returns null key data
+ auto r1 = method.find(state, 1);
+ ASSERT_TRUE(r1.is_found());
+ EXPECT_EQ(r1.get_mapped(), make_mapped(999));
+
+ // row 2: key=30, non-null
+ auto r2 = method.find(state, 2);
+ ASSERT_TRUE(r2.is_found());
+ EXPECT_EQ(r2.get_mapped(), make_mapped(3));
+
+ // row 4: key=50, non-null
+ auto r4 = method.find(state, 4);
+ ASSERT_TRUE(r4.is_found());
+ EXPECT_EQ(r4.get_mapped(), make_mapped(5));
+ }
+
+ // for_each_mapped: PHHashMap::for_each_mapped iterates only non-null keys
+ {
+ size_t count = 0;
+ ht.for_each_mapped([&](auto& mapped) { count++; });
+ EXPECT_EQ(count, 3); // keys 10, 30, 50
+ }
+
+ // DataWithNullKey::size() includes null key
+ EXPECT_EQ(ht.size(), 4); // 3 non-null + 1 null
+}
+
+// ========== Nullable MethodStringNoCache ==========
+// AggregatedDataWithNullableShortStringKey =
DataWithNullKey<StringHashMap<AggregateDataPtr>>
+TEST(HashTableMethodTest, testNullableMethodStringNoCacheAggInsertFindForEach)
{
+ using NullableMethod = MethodSingleNullableColumn<
+ MethodStringNoCache<AggregatedDataWithNullableShortStringKey>>;
+ NullableMethod method;
+ using State = NullableMethod::State;
+
+ // values: {"hello", <null>, "world", <null>}
+ auto col = ColumnHelper::create_nullable_column<DataTypeString>({"hello",
"", "world", ""},
+ {0, 1, 0,
1});
+ ColumnRawPtrs key_columns = {col.get()};
+ const size_t rows = 4;
+
+ // Insert
+ size_t null_create_count = 0;
+ {
+ State state(key_columns);
+ method.init_serialized_keys(key_columns, rows);
+ for (size_t i = 0; i < rows; i++) {
+ method.lazy_emplace(
+ state, i,
+ [&](const auto& ctor, auto& key, auto& origin) {
+ ctor(key, make_mapped(i + 1));
+ },
+ [&](auto& mapped) {
+ null_create_count++;
+ mapped = make_mapped(888);
+ });
+ }
+ }
+
+ EXPECT_EQ(null_create_count, 1);
+
+ auto& ht = *method.hash_table;
+ EXPECT_TRUE(ht.has_null_key_data());
+ EXPECT_EQ(ht.get_null_key_data<AggregateDataPtr>(), make_mapped(888));
+
+ // Find
+ {
+ State state(key_columns);
+ method.init_serialized_keys(key_columns, rows);
+
+ // row 0: "hello"
+ auto r0 = method.find(state, 0);
+ ASSERT_TRUE(r0.is_found());
+ EXPECT_EQ(r0.get_mapped(), make_mapped(1));
+
+ // row 1: null
+ auto r1 = method.find(state, 1);
+ ASSERT_TRUE(r1.is_found());
+ EXPECT_EQ(r1.get_mapped(), make_mapped(888));
+
+ // row 2: "world"
+ auto r2 = method.find(state, 2);
+ ASSERT_TRUE(r2.is_found());
+ EXPECT_EQ(r2.get_mapped(), make_mapped(3));
+ }
+
+ // for_each: StringHashMap::for_each iterates only non-null keys
+ {
+ size_t count = 0;
+ ht.for_each([&](const auto& key, auto& mapped) { count++; });
+ EXPECT_EQ(count, 2); // "hello" and "world"
+ }
+
+ // DataWithNullKey::size() includes null key
+ EXPECT_EQ(ht.size(), 3); // 2 non-null + 1 null
+}
+
+// ========== PHHashMap iterator: traverse, sort, verify, assignment ==========
+TEST(HashTableMethodTest, testPHHashMapIterator) {
+ MethodOneNumber<UInt32, AggData<UInt32>> method;
+ using State = MethodOneNumber<UInt32, AggData<UInt32>>::State;
+
+ auto col = ColumnHelper::create_column<DataTypeInt32>({50, 10, 40, 20,
30});
+ ColumnRawPtrs key_columns = {col.get()};
+ const size_t rows = 5;
+
+ State state(key_columns);
+ method.init_serialized_keys(key_columns, rows);
+ for (size_t i = 0; i < rows; i++) {
+ method.lazy_emplace(
+ state, i,
+ [&](const auto& ctor, auto& key, auto& origin) { ctor(key,
make_mapped(i + 1)); },
+ [](auto& mapped) { FAIL(); });
+ }
+
+ auto& ht = *method.hash_table;
+
+ // Collect all (key, mapped) pairs via iterator
+ std::vector<std::pair<UInt32, AggregateDataPtr>> entries;
+ for (auto it = ht.begin(); it != ht.end(); ++it) {
+ entries.emplace_back(it->get_first(), it->get_second());
+ }
+ ASSERT_EQ(entries.size(), 5);
+
+ // Sort by key and verify
+ std::sort(entries.begin(), entries.end(),
+ [](const auto& a, const auto& b) { return a.first < b.first; });
+ // Inserted: {50→1, 10→2, 40→3, 20→4, 30→5}
+ EXPECT_EQ(entries[0].first, 10);
+ EXPECT_EQ(entries[0].second, make_mapped(2));
+ EXPECT_EQ(entries[1].first, 20);
+ EXPECT_EQ(entries[1].second, make_mapped(4));
+ EXPECT_EQ(entries[2].first, 30);
+ EXPECT_EQ(entries[2].second, make_mapped(5));
+ EXPECT_EQ(entries[3].first, 40);
+ EXPECT_EQ(entries[3].second, make_mapped(3));
+ EXPECT_EQ(entries[4].first, 50);
+ EXPECT_EQ(entries[4].second, make_mapped(1));
+
+ // Iterator assignment: it = begin(), it2 = it
+ auto it = ht.begin();
+ auto it2 = it; // copy
+ EXPECT_TRUE(it == it2);
+ EXPECT_EQ(it->get_first(), it2->get_first());
+ EXPECT_EQ(it->get_second(), it2->get_second());
+
+ ++it;
+ EXPECT_TRUE(it != it2); // diverged after increment
+
+ it2 = it; // reassignment
+ EXPECT_TRUE(it == it2);
+ EXPECT_EQ(it->get_first(), it2->get_first());
+
+ // Empty hash table: begin == end
+ MethodOneNumber<UInt32, AggData<UInt32>> empty_method;
+ EXPECT_TRUE(empty_method.hash_table->begin() ==
empty_method.hash_table->end());
+}
+
+// ========== StringHashMap iterator: traverse, sort, verify, assignment
==========
+TEST(HashTableMethodTest, testStringHashMapIterator) {
+ MethodStringNoCache<AggregatedDataWithShortStringKey> method;
+ using State = MethodStringNoCache<AggregatedDataWithShortStringKey>::State;
+
+ // Different lengths to hit different sub-maps (m1: <=8, m2: <=16, m3:
<=24, ms: >24)
+ auto col = ColumnHelper::create_column<DataTypeString>(
+ {"z", "ab", "hello_world_12345", "tiny",
"a_very_long_string_over_24_chars"});
+ ColumnRawPtrs key_columns = {col.get()};
+ const size_t rows = 5;
+
+ State state(key_columns);
+ method.init_serialized_keys(key_columns, rows);
+ for (size_t i = 0; i < rows; i++) {
+ method.lazy_emplace(
+ state, i,
+ [&](const auto& ctor, auto& key, auto& origin) { ctor(key,
make_mapped(i + 1)); },
+ [](auto& mapped) { FAIL(); });
+ }
+
+ auto& ht = *method.hash_table;
+
+ // Collect all (key_string, mapped) pairs via iterator
+ std::vector<std::pair<std::string, AggregateDataPtr>> entries;
+ for (auto it = ht.begin(); it != ht.end(); ++it) {
+ auto key = it.get_first();
+ entries.emplace_back(std::string(key.data, key.size), it.get_second());
+ }
+ ASSERT_EQ(entries.size(), 5);
+
+ // Sort by key string and verify
+ std::sort(entries.begin(), entries.end(),
+ [](const auto& a, const auto& b) { return a.first < b.first; });
+ // Sorted: "a_very_long...", "ab", "hello_world_12345", "tiny", "z"
+ EXPECT_EQ(entries[0].first, "a_very_long_string_over_24_chars");
+ EXPECT_EQ(entries[0].second, make_mapped(5));
+ EXPECT_EQ(entries[1].first, "ab");
+ EXPECT_EQ(entries[1].second, make_mapped(2));
+ EXPECT_EQ(entries[2].first, "hello_world_12345");
+ EXPECT_EQ(entries[2].second, make_mapped(3));
+ EXPECT_EQ(entries[3].first, "tiny");
+ EXPECT_EQ(entries[3].second, make_mapped(4));
+ EXPECT_EQ(entries[4].first, "z");
+ EXPECT_EQ(entries[4].second, make_mapped(1));
+
+ // Iterator assignment: copy and reassign
+ auto it = ht.begin();
+ auto it2 = it;
+ EXPECT_TRUE(it == it2);
+
+ auto key1 = it.get_first();
+ auto key2 = it2.get_first();
+ EXPECT_EQ(key1, key2);
+
+ ++it;
+ EXPECT_TRUE(it != it2);
+
+ it2 = it; // reassignment
+ EXPECT_TRUE(it == it2);
+
+ // Empty StringHashMap: begin == end
+ MethodStringNoCache<AggregatedDataWithShortStringKey> empty_method;
+ EXPECT_TRUE(empty_method.hash_table->begin() ==
empty_method.hash_table->end());
+}
+
+// ========== DataWithNullKey iterator: only non-null entries, null key
accessed separately ==========
+TEST(HashTableMethodTest, testDataWithNullKeyIterator) {
+ using NullableMethod =
+ MethodSingleNullableColumn<MethodOneNumber<UInt32,
AggDataNullable<UInt32>>>;
+ NullableMethod method;
+ using State = NullableMethod::State;
+
+ // values: {10, 20, 30}, null at row 1
+ auto col = ColumnHelper::create_nullable_column<DataTypeInt32>({10, 20,
30}, {0, 1, 0});
+ ColumnRawPtrs key_columns = {col.get()};
+ const size_t rows = 3;
+
+ State state(key_columns);
+ method.init_serialized_keys(key_columns, rows);
+ for (size_t i = 0; i < rows; i++) {
+ method.lazy_emplace(
+ state, i,
+ [&](const auto& ctor, auto& key, auto& origin) { ctor(key,
make_mapped(i + 1)); },
+ [&](auto& mapped) { mapped = make_mapped(999); });
+ }
+
+ auto& ht = *method.hash_table;
+
+ // Null key is present and must be accessed separately (not via iterator)
+ EXPECT_TRUE(ht.has_null_key_data());
+ EXPECT_EQ(ht.get_null_key_data<AggregateDataPtr>(), make_mapped(999));
+
+ // DataWithNullKey::size() includes null key
+ EXPECT_EQ(ht.size(), 3); // 2 non-null + 1 null
+
+ // Iterator only visits non-null entries
+ std::vector<std::pair<UInt32, AggregateDataPtr>> non_null_entries;
+ for (auto it = ht.begin(); it != ht.end(); ++it) {
+ non_null_entries.emplace_back(it->get_first(), it->get_second());
+ }
+
+ // Only 2 non-null entries in iteration (null key excluded)
+ ASSERT_EQ(non_null_entries.size(), 2);
+ std::sort(non_null_entries.begin(), non_null_entries.end(),
+ [](const auto& a, const auto& b) { return a.first < b.first; });
+ // Inserted: row0=10→1, row2=30→3
+ EXPECT_EQ(non_null_entries[0].first, 10);
+ EXPECT_EQ(non_null_entries[0].second, make_mapped(1));
+ EXPECT_EQ(non_null_entries[1].first, 30);
+ EXPECT_EQ(non_null_entries[1].second, make_mapped(3));
+
+ // Iterator assignment
+ auto it = ht.begin();
+ auto it2 = it;
+ EXPECT_TRUE(it == it2);
+ EXPECT_EQ(it.get_second(), it2.get_second());
+
+ ++it;
+ EXPECT_TRUE(it != it2);
+
+ it2 = it;
+ EXPECT_TRUE(it == it2);
+}
+} // namespace doris
---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]