http://git-wip-us.apache.org/repos/asf/incubator-quickstep/blob/c123bd49/expressions/aggregation/AggregationHandleMin.hpp ---------------------------------------------------------------------- diff --git a/expressions/aggregation/AggregationHandleMin.hpp b/expressions/aggregation/AggregationHandleMin.hpp index 4e0c72b..4a0eca4 100644 --- a/expressions/aggregation/AggregationHandleMin.hpp +++ b/expressions/aggregation/AggregationHandleMin.hpp @@ -28,8 +28,8 @@ #include "catalog/CatalogTypedefs.hpp" #include "expressions/aggregation/AggregationConcreteHandle.hpp" #include "expressions/aggregation/AggregationHandle.hpp" -#include "storage/HashTableBase.hpp" #include "storage/FastHashTable.hpp" +#include "storage/HashTableBase.hpp" #include "threading/SpinMutex.hpp" #include "types/Type.hpp" #include "types/TypedValue.hpp" @@ -56,19 +56,18 @@ class AggregationStateMin : public AggregationState { /** * @brief Copy constructor (ignores mutex). */ - AggregationStateMin(const AggregationStateMin &orig) - : min_(orig.min_) { - } + AggregationStateMin(const AggregationStateMin &orig) : min_(orig.min_) {} /** * @brief Destructor. */ ~AggregationStateMin() override {} - size_t getPayloadSize() const { - return sizeof(TypedValue); - } + std::size_t getPayloadSize() const { return sizeof(TypedValue); } + const std::uint8_t *getPayloadAddress() const { + return reinterpret_cast<const uint8_t *>(&min_); + } private: friend class AggregationHandleMin; @@ -76,9 +75,7 @@ class AggregationStateMin : public AggregationState { explicit AggregationStateMin(const Type &type) : min_(type.getNullableVersion().makeNullValue()) {} - explicit AggregationStateMin(TypedValue &&value) - : min_(std::move(value)) { - } + explicit AggregationStateMin(TypedValue &&value) : min_(std::move(value)) {} TypedValue min_; SpinMutex mutex_; @@ -89,8 +86,7 @@ class AggregationStateMin : public AggregationState { **/ class AggregationHandleMin : public AggregationConcreteHandle { public: - ~AggregationHandleMin() override { - } + ~AggregationHandleMin() override {} AggregationState* createInitialState() const override { return new AggregationStateMin(type_); @@ -98,45 +94,46 @@ class AggregationHandleMin : public AggregationConcreteHandle { AggregationStateHashTableBase* createGroupByHashTable( const HashTableImplType hash_table_impl, - const std::vector<const Type*> &group_by_types, + const std::vector<const Type *> &group_by_types, const std::size_t estimated_num_groups, StorageManager *storage_manager) const override; /** * @brief Iterate with min aggregation state. */ - inline void iterateUnaryInl(AggregationStateMin *state, const TypedValue &value) const { + inline void iterateUnaryInl(AggregationStateMin *state, + const TypedValue &value) const { DCHECK(value.isPlausibleInstanceOf(type_.getSignature())); compareAndUpdate(state, value); } - inline void iterateUnaryInlFast(const TypedValue &value, uint8_t *byte_ptr) const { - DCHECK(value.isPlausibleInstanceOf(type_.getSignature())); - TypedValue *min_ptr = reinterpret_cast<TypedValue *>(byte_ptr); - compareAndUpdateFast(min_ptr, value); + inline void iterateUnaryInlFast(const TypedValue &value, + std::uint8_t *byte_ptr) const { + DCHECK(value.isPlausibleInstanceOf(type_.getSignature())); + TypedValue *min_ptr = reinterpret_cast<TypedValue *>(byte_ptr); + compareAndUpdateFast(min_ptr, value); } - inline void iterateInlFast(const std::vector<TypedValue> &arguments, uint8_t *byte_ptr) const override { - if (block_update) return; - iterateUnaryInlFast(arguments.front(), byte_ptr); + inline void updateState(const std::vector<TypedValue> &arguments, + std::uint8_t *byte_ptr) const override { + if (!block_update_) { + iterateUnaryInlFast(arguments.front(), byte_ptr); + } } - void BlockUpdate() override { - block_update = true; - } + void blockUpdate() override { block_update_ = true; } - void AllowUpdate() override { - block_update = false; - } + void allowUpdate() override { block_update_ = false; } - void initPayload(uint8_t *byte_ptr) const override { + void initPayload(std::uint8_t *byte_ptr) const override { TypedValue *min_ptr = reinterpret_cast<TypedValue *>(byte_ptr); TypedValue t1 = (type_.getNullableVersion().makeNullValue()); *min_ptr = t1; } AggregationState* accumulateColumnVectors( - const std::vector<std::unique_ptr<ColumnVector>> &column_vectors) const override; + const std::vector<std::unique_ptr<ColumnVector>> &column_vectors) + const override; #ifdef QUICKSTEP_ENABLE_VECTOR_COPY_ELISION_SELECTION AggregationState* accumulateValueAccessor( @@ -153,18 +150,20 @@ class AggregationHandleMin : public AggregationConcreteHandle { void mergeStates(const AggregationState &source, AggregationState *destination) const override; - void mergeStatesFast(const uint8_t *source, - uint8_t *destination) const override; + void mergeStatesFast(const std::uint8_t *source, + std::uint8_t *destination) const override; TypedValue finalize(const AggregationState &state) const override { - return static_cast<const AggregationStateMin&>(state).min_; + return static_cast<const AggregationStateMin &>(state).min_; } - inline TypedValue finalizeHashTableEntry(const AggregationState &state) const { - return static_cast<const AggregationStateMin&>(state).min_; + inline TypedValue finalizeHashTableEntry( + const AggregationState &state) const { + return static_cast<const AggregationStateMin &>(state).min_; } - inline TypedValue finalizeHashTableEntryFast(const std::uint8_t *byte_ptr) const { + inline TypedValue finalizeHashTableEntryFast( + const std::uint8_t *byte_ptr) const { const TypedValue *min_ptr = reinterpret_cast<const TypedValue *>(byte_ptr); return TypedValue(*min_ptr); } @@ -175,24 +174,25 @@ class AggregationHandleMin : public AggregationConcreteHandle { int index) const override; /** - * @brief Implementation of AggregationHandle::aggregateOnDistinctifyHashTableForSingle() + * @brief Implementation of + * AggregationHandle::aggregateOnDistinctifyHashTableForSingle() * for MIN aggregation. */ AggregationState* aggregateOnDistinctifyHashTableForSingle( - const AggregationStateHashTableBase &distinctify_hash_table) const override; + const AggregationStateHashTableBase &distinctify_hash_table) + const override; /** - * @brief Implementation of AggregationHandle::aggregateOnDistinctifyHashTableForGroupBy() + * @brief Implementation of + * AggregationHandle::aggregateOnDistinctifyHashTableForGroupBy() * for MIN aggregation. */ void aggregateOnDistinctifyHashTableForGroupBy( const AggregationStateHashTableBase &distinctify_hash_table, AggregationStateHashTableBase *aggregation_hash_table, - int index) const override; + std::size_t index) const override; - size_t getPayloadSize() const override { - return sizeof(TypedValue); - } + std::size_t getPayloadSize() const override { return sizeof(TypedValue); } private: friend class AggregateFunctionMin; @@ -205,23 +205,28 @@ class AggregationHandleMin : public AggregationConcreteHandle { explicit AggregationHandleMin(const Type &type); /** - * @brief compare the value with min_ and update it if the value is smaller than + * @brief compare the value with min_ and update it if the value is smaller + *than * current minimum. NULLs are ignored. * * @param value A TypedValue to compare. **/ - inline void compareAndUpdate(AggregationStateMin *state, const TypedValue &value) const { + inline void compareAndUpdate(AggregationStateMin *state, + const TypedValue &value) const { if (value.isNull()) return; SpinMutexLock lock(state->mutex_); - if (state->min_.isNull() || fast_comparator_->compareTypedValues(value, state->min_)) { + if (state->min_.isNull() || + fast_comparator_->compareTypedValues(value, state->min_)) { state->min_ = value; } } - inline void compareAndUpdateFast(TypedValue *min_ptr, const TypedValue &value) const { + inline void compareAndUpdateFast(TypedValue *min_ptr, + const TypedValue &value) const { if (value.isNull()) return; - if (min_ptr->isNull() || fast_comparator_->compareTypedValues(value, *min_ptr)) { + if (min_ptr->isNull() || + fast_comparator_->compareTypedValues(value, *min_ptr)) { *min_ptr = value; } } @@ -229,7 +234,7 @@ class AggregationHandleMin : public AggregationConcreteHandle { const Type &type_; std::unique_ptr<UncheckedComparator> fast_comparator_; - bool block_update; + bool block_update_; DISALLOW_COPY_AND_ASSIGN(AggregationHandleMin); };
http://git-wip-us.apache.org/repos/asf/incubator-quickstep/blob/c123bd49/expressions/aggregation/AggregationHandleSum.cpp ---------------------------------------------------------------------- diff --git a/expressions/aggregation/AggregationHandleSum.cpp b/expressions/aggregation/AggregationHandleSum.cpp index 1b0bbcd..642d88d 100644 --- a/expressions/aggregation/AggregationHandleSum.cpp +++ b/expressions/aggregation/AggregationHandleSum.cpp @@ -43,7 +43,7 @@ namespace quickstep { class StorageManager; AggregationHandleSum::AggregationHandleSum(const Type &type) - : argument_type_(type), block_update(false) { + : argument_type_(type), block_update_(false) { // We sum Int as Long and Float as Double so that we have more headroom when // adding many values. TypeID type_precision_id; @@ -66,11 +66,13 @@ AggregationHandleSum::AggregationHandleSum(const Type &type) // Make operators to do arithmetic: // Add operator for summing argument values. - fast_operator_.reset(BinaryOperationFactory::GetBinaryOperation(BinaryOperationID::kAdd) - .makeUncheckedBinaryOperatorForTypes(sum_type, argument_type_)); + fast_operator_.reset( + BinaryOperationFactory::GetBinaryOperation(BinaryOperationID::kAdd) + .makeUncheckedBinaryOperatorForTypes(sum_type, argument_type_)); // Add operator for merging states. - merge_operator_.reset(BinaryOperationFactory::GetBinaryOperation(BinaryOperationID::kAdd) - .makeUncheckedBinaryOperatorForTypes(sum_type, sum_type)); + merge_operator_.reset( + BinaryOperationFactory::GetBinaryOperation(BinaryOperationID::kAdd) + .makeUncheckedBinaryOperatorForTypes(sum_type, sum_type)); // Result is nullable, because SUM() over 0 values (or all NULL values) is // NULL. @@ -79,14 +81,11 @@ AggregationHandleSum::AggregationHandleSum(const Type &type) AggregationStateHashTableBase* AggregationHandleSum::createGroupByHashTable( const HashTableImplType hash_table_impl, - const std::vector<const Type*> &group_by_types, + const std::vector<const Type *> &group_by_types, const std::size_t estimated_num_groups, StorageManager *storage_manager) const { return AggregationStateHashTableFactory<AggregationStateSum>::CreateResizable( - hash_table_impl, - group_by_types, - estimated_num_groups, - storage_manager); + hash_table_impl, group_by_types, estimated_num_groups, storage_manager); } AggregationState* AggregationHandleSum::accumulateColumnVectors( @@ -95,9 +94,7 @@ AggregationState* AggregationHandleSum::accumulateColumnVectors( << "Got wrong number of ColumnVectors for SUM: " << column_vectors.size(); std::size_t num_tuples = 0; TypedValue cv_sum = fast_operator_->accumulateColumnVector( - blank_state_.sum_, - *column_vectors.front(), - &num_tuples); + blank_state_.sum_, *column_vectors.front(), &num_tuples); return new AggregationStateSum(std::move(cv_sum), num_tuples == 0); } @@ -110,10 +107,7 @@ AggregationState* AggregationHandleSum::accumulateValueAccessor( std::size_t num_tuples = 0; TypedValue va_sum = fast_operator_->accumulateValueAccessor( - blank_state_.sum_, - accessor, - accessor_ids.front(), - &num_tuples); + blank_state_.sum_, accessor, accessor_ids.front(), &num_tuples); return new AggregationStateSum(std::move(va_sum), num_tuples == 0); } #endif @@ -127,31 +121,37 @@ void AggregationHandleSum::aggregateValueAccessorIntoHashTable( << "Got wrong number of arguments for SUM: " << argument_ids.size(); } -void AggregationHandleSum::mergeStates( - const AggregationState &source, - AggregationState *destination) const { - const AggregationStateSum &sum_source = static_cast<const AggregationStateSum&>(source); - AggregationStateSum *sum_destination = static_cast<AggregationStateSum*>(destination); +void AggregationHandleSum::mergeStates(const AggregationState &source, + AggregationState *destination) const { + const AggregationStateSum &sum_source = + static_cast<const AggregationStateSum &>(source); + AggregationStateSum *sum_destination = + static_cast<AggregationStateSum *>(destination); SpinMutexLock lock(sum_destination->mutex_); - sum_destination->sum_ = merge_operator_->applyToTypedValues(sum_destination->sum_, - sum_source.sum_); + sum_destination->sum_ = merge_operator_->applyToTypedValues( + sum_destination->sum_, sum_source.sum_); sum_destination->null_ = sum_destination->null_ && sum_source.null_; } -void AggregationHandleSum::mergeStatesFast( - const uint8_t *source, - uint8_t *destination) const { - const TypedValue *src_sum_ptr = reinterpret_cast<const TypedValue *>(source+blank_state_.sum_offset); - const bool *src_null_ptr = reinterpret_cast<const bool *>(source+blank_state_.null_offset); - TypedValue *dst_sum_ptr = reinterpret_cast<TypedValue *>(destination+blank_state_.sum_offset); - bool *dst_null_ptr = reinterpret_cast<bool *>(destination+blank_state_.null_offset); - *dst_sum_ptr = merge_operator_->applyToTypedValues(*dst_sum_ptr, *src_sum_ptr); - *dst_null_ptr = (*dst_null_ptr) && (*src_null_ptr); +void AggregationHandleSum::mergeStatesFast(const std::uint8_t *source, + std::uint8_t *destination) const { + const TypedValue *src_sum_ptr = + reinterpret_cast<const TypedValue *>(source + blank_state_.sum_offset_); + const bool *src_null_ptr = + reinterpret_cast<const bool *>(source + blank_state_.null_offset_); + TypedValue *dst_sum_ptr = + reinterpret_cast<TypedValue *>(destination + blank_state_.sum_offset_); + bool *dst_null_ptr = + reinterpret_cast<bool *>(destination + blank_state_.null_offset_); + *dst_sum_ptr = + merge_operator_->applyToTypedValues(*dst_sum_ptr, *src_sum_ptr); + *dst_null_ptr = (*dst_null_ptr) && (*src_null_ptr); } TypedValue AggregationHandleSum::finalize(const AggregationState &state) const { - const AggregationStateSum &agg_state = static_cast<const AggregationStateSum&>(state); + const AggregationStateSum &agg_state = + static_cast<const AggregationStateSum &>(state); if (agg_state.null_) { // SUM() over no values is NULL. return result_type_->makeNullValue(); @@ -165,31 +165,26 @@ ColumnVector* AggregationHandleSum::finalizeHashTable( std::vector<std::vector<TypedValue>> *group_by_keys, int index) const { return finalizeHashTableHelperFast<AggregationHandleSum, - AggregationStateFastHashTable>( - *result_type_, - hash_table, - group_by_keys, - index); + AggregationStateFastHashTable>( + *result_type_, hash_table, group_by_keys, index); } -AggregationState* AggregationHandleSum::aggregateOnDistinctifyHashTableForSingle( +AggregationState* +AggregationHandleSum::aggregateOnDistinctifyHashTableForSingle( const AggregationStateHashTableBase &distinctify_hash_table) const { return aggregateOnDistinctifyHashTableForSingleUnaryHelperFast< AggregationHandleSum, - AggregationStateSum>( - distinctify_hash_table); + AggregationStateSum>(distinctify_hash_table); } void AggregationHandleSum::aggregateOnDistinctifyHashTableForGroupBy( const AggregationStateHashTableBase &distinctify_hash_table, AggregationStateHashTableBase *aggregation_hash_table, - int index) const { + std::size_t index) const { aggregateOnDistinctifyHashTableForGroupByUnaryHelperFast< AggregationHandleSum, AggregationStateFastHashTable>( - distinctify_hash_table, - aggregation_hash_table, - index); + distinctify_hash_table, aggregation_hash_table, index); } } // namespace quickstep http://git-wip-us.apache.org/repos/asf/incubator-quickstep/blob/c123bd49/expressions/aggregation/AggregationHandleSum.hpp ---------------------------------------------------------------------- diff --git a/expressions/aggregation/AggregationHandleSum.hpp b/expressions/aggregation/AggregationHandleSum.hpp index 3e1de48..8d719ab 100644 --- a/expressions/aggregation/AggregationHandleSum.hpp +++ b/expressions/aggregation/AggregationHandleSum.hpp @@ -28,8 +28,8 @@ #include "catalog/CatalogTypedefs.hpp" #include "expressions/aggregation/AggregationConcreteHandle.hpp" #include "expressions/aggregation/AggregationHandle.hpp" -#include "storage/HashTableBase.hpp" #include "storage/FastHashTable.hpp" +#include "storage/HashTableBase.hpp" #include "threading/SpinMutex.hpp" #include "types/Type.hpp" #include "types/TypedValue.hpp" @@ -59,27 +59,31 @@ class AggregationStateSum : public AggregationState { AggregationStateSum(const AggregationStateSum &orig) : sum_(orig.sum_), null_(orig.null_), - sum_offset(orig.sum_offset), - null_offset(orig.null_offset) { + sum_offset_(orig.sum_offset_), + null_offset_(orig.null_offset_) {} + + std::size_t getPayloadSize() const { + std::size_t p1 = reinterpret_cast<std::size_t>(&sum_); + std::size_t p2 = reinterpret_cast<std::size_t>(&mutex_); + return (p2 - p1); + } + + const std::uint8_t* getPayloadAddress() const { + return reinterpret_cast<const uint8_t *>(&sum_); } private: friend class AggregationHandleSum; AggregationStateSum() - : sum_(0), null_(true), sum_offset(0), - null_offset(reinterpret_cast<uint8_t *>(&null_)-reinterpret_cast<uint8_t *>(&sum_)) { - } + : sum_(0), + null_(true), + sum_offset_(0), + null_offset_(reinterpret_cast<std::uint8_t *>(&null_) - + reinterpret_cast<std::uint8_t *>(&sum_)) {} AggregationStateSum(TypedValue &&sum, const bool is_null) - : sum_(std::move(sum)), null_(is_null) { - } - - size_t getPayloadSize() const { - size_t p1 = reinterpret_cast<size_t>(&sum_); - size_t p2 = reinterpret_cast<size_t>(&mutex_); - return (p2-p1); - } + : sum_(std::move(sum)), null_(is_null) {} // TODO(shoban): We might want to specialize sum_ to use atomics for int types // similar to in AggregationStateCount. @@ -87,17 +91,15 @@ class AggregationStateSum : public AggregationState { bool null_; SpinMutex mutex_; - int sum_offset, null_offset; + int sum_offset_, null_offset_; }; - /** * @brief An aggregationhandle for sum. **/ class AggregationHandleSum : public AggregationConcreteHandle { public: - ~AggregationHandleSum() override { - } + ~AggregationHandleSum() override {} AggregationState* createInitialState() const override { return new AggregationStateSum(blank_state_); @@ -105,11 +107,12 @@ class AggregationHandleSum : public AggregationConcreteHandle { AggregationStateHashTableBase* createGroupByHashTable( const HashTableImplType hash_table_impl, - const std::vector<const Type*> &group_by_types, + const std::vector<const Type *> &group_by_types, const std::size_t estimated_num_groups, StorageManager *storage_manager) const override; - inline void iterateUnaryInl(AggregationStateSum *state, const TypedValue &value) const { + inline void iterateUnaryInl(AggregationStateSum *state, + const TypedValue &value) const { DCHECK(value.isPlausibleInstanceOf(argument_type_.getSignature())); if (value.isNull()) return; @@ -118,37 +121,41 @@ class AggregationHandleSum : public AggregationConcreteHandle { state->null_ = false; } - inline void iterateUnaryInlFast(const TypedValue &value, uint8_t *byte_ptr) const { + inline void iterateUnaryInlFast(const TypedValue &value, + std::uint8_t *byte_ptr) const { DCHECK(value.isPlausibleInstanceOf(argument_type_.getSignature())); if (value.isNull()) return; - TypedValue *sum_ptr = reinterpret_cast<TypedValue *>(byte_ptr + blank_state_.sum_offset); - bool *null_ptr = reinterpret_cast<bool *>(byte_ptr + blank_state_.null_offset); + TypedValue *sum_ptr = + reinterpret_cast<TypedValue *>(byte_ptr + blank_state_.sum_offset_); + bool *null_ptr = + reinterpret_cast<bool *>(byte_ptr + blank_state_.null_offset_); *sum_ptr = fast_operator_->applyToTypedValues(*sum_ptr, value); *null_ptr = false; } - inline void iterateInlFast(const std::vector<TypedValue> &arguments, uint8_t *byte_ptr) const override { - if (block_update) return; - iterateUnaryInlFast(arguments.front(), byte_ptr); + inline void updateState(const std::vector<TypedValue> &arguments, + std::uint8_t *byte_ptr) const override { + if (!block_update_) { + iterateUnaryInlFast(arguments.front(), byte_ptr); + } } - void BlockUpdate() override { - block_update = true; - } + void blockUpdate() override { block_update_ = true; } - void AllowUpdate() override { - block_update = false; - } + void allowUpdate() override { block_update_ = false; } - void initPayload(uint8_t *byte_ptr) const override { - TypedValue *sum_ptr = reinterpret_cast<TypedValue *>(byte_ptr + blank_state_.sum_offset); - bool *null_ptr = reinterpret_cast<bool *>(byte_ptr + blank_state_.null_offset); + void initPayload(std::uint8_t *byte_ptr) const override { + TypedValue *sum_ptr = + reinterpret_cast<TypedValue *>(byte_ptr + blank_state_.sum_offset_); + bool *null_ptr = + reinterpret_cast<bool *>(byte_ptr + blank_state_.null_offset_); *sum_ptr = blank_state_.sum_; *null_ptr = true; } AggregationState* accumulateColumnVectors( - const std::vector<std::unique_ptr<ColumnVector>> &column_vectors) const override; + const std::vector<std::unique_ptr<ColumnVector>> &column_vectors) + const override; #ifdef QUICKSTEP_ENABLE_VECTOR_COPY_ELISION_SELECTION AggregationState* accumulateValueAccessor( @@ -165,18 +172,21 @@ class AggregationHandleSum : public AggregationConcreteHandle { void mergeStates(const AggregationState &source, AggregationState *destination) const override; - void mergeStatesFast(const uint8_t *source, - uint8_t *destination) const override; + void mergeStatesFast(const std::uint8_t *source, + std::uint8_t *destination) const override; TypedValue finalize(const AggregationState &state) const override; - inline TypedValue finalizeHashTableEntry(const AggregationState &state) const { - return static_cast<const AggregationStateSum&>(state).sum_; + inline TypedValue finalizeHashTableEntry( + const AggregationState &state) const { + return static_cast<const AggregationStateSum &>(state).sum_; } - inline TypedValue finalizeHashTableEntryFast(const uint8_t *byte_ptr) const { - uint8_t *value_ptr = const_cast<uint8_t*>(byte_ptr); - TypedValue *sum_ptr = reinterpret_cast<TypedValue *>(value_ptr + blank_state_.sum_offset); + inline TypedValue finalizeHashTableEntryFast( + const std::uint8_t *byte_ptr) const { + std::uint8_t *value_ptr = const_cast<std::uint8_t *>(byte_ptr); + TypedValue *sum_ptr = + reinterpret_cast<TypedValue *>(value_ptr + blank_state_.sum_offset_); return *sum_ptr; } @@ -186,23 +196,26 @@ class AggregationHandleSum : public AggregationConcreteHandle { int index) const override; /** - * @brief Implementation of AggregationHandle::aggregateOnDistinctifyHashTableForSingle() + * @brief Implementation of + * AggregationHandle::aggregateOnDistinctifyHashTableForSingle() * for SUM aggregation. */ AggregationState* aggregateOnDistinctifyHashTableForSingle( - const AggregationStateHashTableBase &distinctify_hash_table) const override; + const AggregationStateHashTableBase &distinctify_hash_table) + const override; /** - * @brief Implementation of AggregationHandle::aggregateOnDistinctifyHashTableForGroupBy() + * @brief Implementation of + * AggregationHandle::aggregateOnDistinctifyHashTableForGroupBy() * for SUM aggregation. */ void aggregateOnDistinctifyHashTableForGroupBy( const AggregationStateHashTableBase &distinctify_hash_table, AggregationStateHashTableBase *aggregation_hash_table, - int index) const override; + std::size_t index) const override; - size_t getPayloadSize() const override { - return blank_state_.getPayloadSize(); + std::size_t getPayloadSize() const override { + return blank_state_.getPayloadSize(); } private: @@ -221,7 +234,7 @@ class AggregationHandleSum : public AggregationConcreteHandle { std::unique_ptr<UncheckedBinaryOperator> fast_operator_; std::unique_ptr<UncheckedBinaryOperator> merge_operator_; - bool block_update; + bool block_update_; DISALLOW_COPY_AND_ASSIGN(AggregationHandleSum); }; http://git-wip-us.apache.org/repos/asf/incubator-quickstep/blob/c123bd49/expressions/aggregation/CMakeLists.txt ---------------------------------------------------------------------- diff --git a/expressions/aggregation/CMakeLists.txt b/expressions/aggregation/CMakeLists.txt index 30f9784..e9503f7 100644 --- a/expressions/aggregation/CMakeLists.txt +++ b/expressions/aggregation/CMakeLists.txt @@ -280,45 +280,46 @@ target_link_libraries(quickstep_expressions_aggregation # Tests: # Unified executable to ammortize cost of linking. -# add_executable(AggregationHandle_tests -# "${CMAKE_CURRENT_SOURCE_DIR}/tests/AggregationHandleAvg_unittest.cpp" -# "${CMAKE_CURRENT_SOURCE_DIR}/tests/AggregationHandleCount_unittest.cpp" -# "${CMAKE_CURRENT_SOURCE_DIR}/tests/AggregationHandleMax_unittest.cpp" -# "${CMAKE_CURRENT_SOURCE_DIR}/tests/AggregationHandleMin_unittest.cpp" -# "${CMAKE_CURRENT_SOURCE_DIR}/tests/AggregationHandleSum_unittest.cpp") -# target_link_libraries(AggregationHandle_tests -# gtest -# gtest_main -# quickstep_catalog_CatalogTypedefs -# quickstep_expressions_aggregation_AggregateFunction -# quickstep_expressions_aggregation_AggregateFunctionFactory -# quickstep_expressions_aggregation_AggregationHandle -# quickstep_expressions_aggregation_AggregationHandleAvg -# quickstep_expressions_aggregation_AggregationHandleCount -# quickstep_expressions_aggregation_AggregationHandleMax -# quickstep_expressions_aggregation_AggregationHandleMin -# quickstep_expressions_aggregation_AggregationHandleSum -# quickstep_expressions_aggregation_AggregationID -# quickstep_storage_HashTableBase -# quickstep_storage_StorageManager -# quickstep_types_CharType -# quickstep_types_DateOperatorOverloads -# quickstep_types_DatetimeIntervalType -# quickstep_types_DatetimeType -# quickstep_types_DoubleType -# quickstep_types_FloatType -# quickstep_types_IntType -# quickstep_types_IntervalLit -# quickstep_types_LongType -# quickstep_types_Type -# quickstep_types_TypeFactory -# quickstep_types_TypeID -# quickstep_types_TypedValue -# quickstep_types_VarCharType -# quickstep_types_YearMonthIntervalType -# quickstep_types_containers_ColumnVector -# quickstep_types_containers_ColumnVectorsValueAccessor -# quickstep_types_operations_comparisons_Comparison -# quickstep_types_operations_comparisons_ComparisonFactory -# quickstep_types_operations_comparisons_ComparisonID) -#add_test(AggregationHandle_tests AggregationHandle_tests) +add_executable(AggregationHandle_tests + "${CMAKE_CURRENT_SOURCE_DIR}/tests/AggregationHandleAvg_unittest.cpp" + "${CMAKE_CURRENT_SOURCE_DIR}/tests/AggregationHandleCount_unittest.cpp" + "${CMAKE_CURRENT_SOURCE_DIR}/tests/AggregationHandleMax_unittest.cpp" + "${CMAKE_CURRENT_SOURCE_DIR}/tests/AggregationHandleMin_unittest.cpp" + "${CMAKE_CURRENT_SOURCE_DIR}/tests/AggregationHandleSum_unittest.cpp") +target_link_libraries(AggregationHandle_tests + gtest + gtest_main + quickstep_catalog_CatalogTypedefs + quickstep_expressions_aggregation_AggregateFunction + quickstep_expressions_aggregation_AggregateFunctionFactory + quickstep_expressions_aggregation_AggregationHandle + quickstep_expressions_aggregation_AggregationHandleAvg + quickstep_expressions_aggregation_AggregationHandleCount + quickstep_expressions_aggregation_AggregationHandleMax + quickstep_expressions_aggregation_AggregationHandleMin + quickstep_expressions_aggregation_AggregationHandleSum + quickstep_expressions_aggregation_AggregationID + quickstep_storage_AggregationOperationState + quickstep_storage_HashTableBase + quickstep_storage_StorageManager + quickstep_types_CharType + quickstep_types_DateOperatorOverloads + quickstep_types_DatetimeIntervalType + quickstep_types_DatetimeType + quickstep_types_DoubleType + quickstep_types_FloatType + quickstep_types_IntType + quickstep_types_IntervalLit + quickstep_types_LongType + quickstep_types_Type + quickstep_types_TypeFactory + quickstep_types_TypeID + quickstep_types_TypedValue + quickstep_types_VarCharType + quickstep_types_YearMonthIntervalType + quickstep_types_containers_ColumnVector + quickstep_types_containers_ColumnVectorsValueAccessor + quickstep_types_operations_comparisons_Comparison + quickstep_types_operations_comparisons_ComparisonFactory + quickstep_types_operations_comparisons_ComparisonID) +add_test(AggregationHandle_tests AggregationHandle_tests) http://git-wip-us.apache.org/repos/asf/incubator-quickstep/blob/c123bd49/expressions/aggregation/tests/AggregationHandleAvg_unittest.cpp ---------------------------------------------------------------------- diff --git a/expressions/aggregation/tests/AggregationHandleAvg_unittest.cpp b/expressions/aggregation/tests/AggregationHandleAvg_unittest.cpp index afc02ec..79d4448 100644 --- a/expressions/aggregation/tests/AggregationHandleAvg_unittest.cpp +++ b/expressions/aggregation/tests/AggregationHandleAvg_unittest.cpp @@ -28,6 +28,8 @@ #include "expressions/aggregation/AggregationHandle.hpp" #include "expressions/aggregation/AggregationHandleAvg.hpp" #include "expressions/aggregation/AggregationID.hpp" +#include "storage/AggregationOperationState.hpp" +#include "storage/FastHashTableFactory.hpp" #include "storage/StorageManager.hpp" #include "types/CharType.hpp" #include "types/DateOperatorOverloads.hpp" @@ -53,51 +55,56 @@ namespace quickstep { -class AggregationHandleAvgTest : public::testing::Test { +class AggregationHandleAvgTest : public ::testing::Test { protected: static const int kNumSamples = 100; // Helper method that calls AggregationHandleAvg::iterateUnaryInl() to // aggregate 'value' into '*state'. void iterateHandle(AggregationState *state, const TypedValue &value) { - static_cast<const AggregationHandleAvg&>(*aggregation_handle_avg_).iterateUnaryInl( - static_cast<AggregationStateAvg*>(state), - value); + static_cast<const AggregationHandleAvg &>(*aggregation_handle_avg_) + .iterateUnaryInl(static_cast<AggregationStateAvg *>(state), value); } void initializeHandle(const Type &type) { aggregation_handle_avg_.reset( - AggregateFunctionFactory::Get(AggregationID::kAvg).createHandle( - std::vector<const Type*>(1, &type))); + AggregateFunctionFactory::Get(AggregationID::kAvg) + .createHandle(std::vector<const Type *>(1, &type))); aggregation_handle_avg_state_.reset( aggregation_handle_avg_->createInitialState()); } static bool ApplyToTypesTest(TypeID typeID) { - const Type &type = (typeID == kChar || typeID == kVarChar) ? - TypeFactory::GetType(typeID, static_cast<std::size_t>(10)) : - TypeFactory::GetType(typeID); + const Type &type = + (typeID == kChar || typeID == kVarChar) + ? TypeFactory::GetType(typeID, static_cast<std::size_t>(10)) + : TypeFactory::GetType(typeID); - return AggregateFunctionFactory::Get(AggregationID::kAvg).canApplyToTypes( - std::vector<const Type*>(1, &type)); + return AggregateFunctionFactory::Get(AggregationID::kAvg) + .canApplyToTypes(std::vector<const Type *>(1, &type)); } static bool ResultTypeForArgumentTypeTest(TypeID input_type_id, TypeID output_type_id) { - const Type *result_type - = AggregateFunctionFactory::Get(AggregationID::kAvg).resultTypeForArgumentTypes( - std::vector<const Type*>(1, &TypeFactory::GetType(input_type_id))); + const Type *result_type = + AggregateFunctionFactory::Get(AggregationID::kAvg) + .resultTypeForArgumentTypes(std::vector<const Type *>( + 1, &TypeFactory::GetType(input_type_id))); return (result_type->getTypeID() == output_type_id); } template <typename CppType> - static void CheckAvgValue( - CppType expected, - const AggregationHandle &handle, - const AggregationState &state) { + static void CheckAvgValue(CppType expected, + const AggregationHandle &handle, + const AggregationState &state) { EXPECT_EQ(expected, handle.finalize(state).getLiteral<CppType>()); } + template <typename CppType> + static void CheckAvgValue(CppType expected, const TypedValue &value) { + EXPECT_EQ(expected, value.getLiteral<CppType>()); + } + // Static templated method for set a meaningful value to data types. template <typename CppType> static void SetDataType(int value, CppType *data) { @@ -108,7 +115,9 @@ class AggregationHandleAvgTest : public::testing::Test { void checkAggregationAvgGeneric() { const GenericType &type = GenericType::Instance(true); initializeHandle(type); - EXPECT_TRUE(aggregation_handle_avg_->finalize(*aggregation_handle_avg_state_).isNull()); + EXPECT_TRUE( + aggregation_handle_avg_->finalize(*aggregation_handle_avg_state_) + .isNull()); typename GenericType::cpptype val; typename GenericType::cpptype sum; @@ -119,15 +128,16 @@ class AggregationHandleAvgTest : public::testing::Test { if (type.getTypeID() == kInt || type.getTypeID() == kLong) { SetDataType(i - 10, &val); } else { - SetDataType(static_cast<float>(i - 10)/10, &val); + SetDataType(static_cast<float>(i - 10) / 10, &val); } iterateHandle(aggregation_handle_avg_state_.get(), type.makeValue(&val)); sum += val; } iterateHandle(aggregation_handle_avg_state_.get(), type.makeNullValue()); - CheckAvgValue<typename OutputType::cpptype>(static_cast<typename OutputType::cpptype>(sum) / kNumSamples, - *aggregation_handle_avg_, - *aggregation_handle_avg_state_); + CheckAvgValue<typename OutputType::cpptype>( + static_cast<typename OutputType::cpptype>(sum) / kNumSamples, + *aggregation_handle_avg_, + *aggregation_handle_avg_state_); // Test mergeStates(). std::unique_ptr<AggregationState> merge_state( @@ -140,7 +150,7 @@ class AggregationHandleAvgTest : public::testing::Test { if (type.getTypeID() == kInt || type.getTypeID() == kLong) { SetDataType(i - 10, &val); } else { - SetDataType(static_cast<float>(i - 10)/10, &val); + SetDataType(static_cast<float>(i - 10) / 10, &val); } iterateHandle(merge_state.get(), type.makeValue(&val)); sum += val; @@ -155,7 +165,8 @@ class AggregationHandleAvgTest : public::testing::Test { } template <typename GenericType> - ColumnVector *createColumnVectorGeneric(const Type &type, typename GenericType::cpptype *sum) { + ColumnVector* createColumnVectorGeneric(const Type &type, + typename GenericType::cpptype *sum) { NativeColumnVector *column = new NativeColumnVector(type, kNumSamples + 3); typename GenericType::cpptype val; @@ -166,12 +177,12 @@ class AggregationHandleAvgTest : public::testing::Test { if (type.getTypeID() == kInt || type.getTypeID() == kLong) { SetDataType(i - 10, &val); } else { - SetDataType(static_cast<float>(i - 10)/10, &val); + SetDataType(static_cast<float>(i - 10) / 10, &val); } column->appendTypedValue(type.makeValue(&val)); *sum += val; // One NULL in the middle. - if (i == kNumSamples/2) { + if (i == kNumSamples / 2) { column->appendTypedValue(type.makeNullValue()); } } @@ -184,12 +195,15 @@ class AggregationHandleAvgTest : public::testing::Test { void checkAggregationAvgGenericColumnVector() { const GenericType &type = GenericType::Instance(true); initializeHandle(type); - EXPECT_TRUE(aggregation_handle_avg_->finalize(*aggregation_handle_avg_state_).isNull()); + EXPECT_TRUE( + aggregation_handle_avg_->finalize(*aggregation_handle_avg_state_) + .isNull()); typename GenericType::cpptype sum; SetDataType(0, &sum); std::vector<std::unique_ptr<ColumnVector>> column_vectors; - column_vectors.emplace_back(createColumnVectorGeneric<GenericType>(type, &sum)); + column_vectors.emplace_back( + createColumnVectorGeneric<GenericType>(type, &sum)); std::unique_ptr<AggregationState> cv_state( aggregation_handle_avg_->accumulateColumnVectors(column_vectors)); @@ -201,7 +215,8 @@ class AggregationHandleAvgTest : public::testing::Test { *aggregation_handle_avg_, *cv_state); - aggregation_handle_avg_->mergeStates(*cv_state, aggregation_handle_avg_state_.get()); + aggregation_handle_avg_->mergeStates(*cv_state, + aggregation_handle_avg_state_.get()); CheckAvgValue<typename OutputType::cpptype>( static_cast<typename OutputType::cpptype>(sum) / kNumSamples, *aggregation_handle_avg_, @@ -213,16 +228,19 @@ class AggregationHandleAvgTest : public::testing::Test { void checkAggregationAvgGenericValueAccessor() { const GenericType &type = GenericType::Instance(true); initializeHandle(type); - EXPECT_TRUE(aggregation_handle_avg_->finalize(*aggregation_handle_avg_state_).isNull()); + EXPECT_TRUE( + aggregation_handle_avg_->finalize(*aggregation_handle_avg_state_) + .isNull()); typename GenericType::cpptype sum; SetDataType(0, &sum); - std::unique_ptr<ColumnVectorsValueAccessor> accessor(new ColumnVectorsValueAccessor()); + std::unique_ptr<ColumnVectorsValueAccessor> accessor( + new ColumnVectorsValueAccessor()); accessor->addColumn(createColumnVectorGeneric<GenericType>(type, &sum)); std::unique_ptr<AggregationState> va_state( - aggregation_handle_avg_->accumulateValueAccessor(accessor.get(), - std::vector<attribute_id>(1, 0))); + aggregation_handle_avg_->accumulateValueAccessor( + accessor.get(), std::vector<attribute_id>(1, 0))); // Test the state generated directly by accumulateValueAccessor(), and also // test after merging back. @@ -231,7 +249,8 @@ class AggregationHandleAvgTest : public::testing::Test { *aggregation_handle_avg_, *va_state); - aggregation_handle_avg_->mergeStates(*va_state, aggregation_handle_avg_state_.get()); + aggregation_handle_avg_->mergeStates(*va_state, + aggregation_handle_avg_state_.get()); CheckAvgValue<typename OutputType::cpptype>( static_cast<typename OutputType::cpptype>(sum) / kNumSamples, *aggregation_handle_avg_, @@ -255,12 +274,14 @@ void AggregationHandleAvgTest::CheckAvgValue<double>( } template <> -void AggregationHandleAvgTest::SetDataType<DatetimeIntervalLit>(int value, DatetimeIntervalLit *data) { +void AggregationHandleAvgTest::SetDataType<DatetimeIntervalLit>( + int value, DatetimeIntervalLit *data) { data->interval_ticks = value; } template <> -void AggregationHandleAvgTest::SetDataType<YearMonthIntervalLit>(int value, YearMonthIntervalLit *data) { +void AggregationHandleAvgTest::SetDataType<YearMonthIntervalLit>( + int value, YearMonthIntervalLit *data) { data->months = value; } @@ -307,11 +328,13 @@ TEST_F(AggregationHandleAvgTest, DoubleTypeColumnVectorTest) { } TEST_F(AggregationHandleAvgTest, DatetimeIntervalTypeColumnVectorTest) { - checkAggregationAvgGenericColumnVector<DatetimeIntervalType, DatetimeIntervalType>(); + checkAggregationAvgGenericColumnVector<DatetimeIntervalType, + DatetimeIntervalType>(); } TEST_F(AggregationHandleAvgTest, YearMonthIntervalTypeColumnVectorTest) { - checkAggregationAvgGenericColumnVector<YearMonthIntervalType, YearMonthIntervalType>(); + checkAggregationAvgGenericColumnVector<YearMonthIntervalType, + YearMonthIntervalType>(); } #ifdef QUICKSTEP_ENABLE_VECTOR_COPY_ELISION_SELECTION @@ -332,11 +355,13 @@ TEST_F(AggregationHandleAvgTest, DoubleTypeValueAccessorTest) { } TEST_F(AggregationHandleAvgTest, DatetimeIntervalTypeValueAccessorTest) { - checkAggregationAvgGenericValueAccessor<DatetimeIntervalType, DatetimeIntervalType>(); + checkAggregationAvgGenericValueAccessor<DatetimeIntervalType, + DatetimeIntervalType>(); } TEST_F(AggregationHandleAvgTest, YearMonthIntervalTypeValueAccessorTest) { - checkAggregationAvgGenericValueAccessor<YearMonthIntervalType, YearMonthIntervalType>(); + checkAggregationAvgGenericValueAccessor<YearMonthIntervalType, + YearMonthIntervalType>(); } #endif // QUICKSTEP_ENABLE_VECTOR_COPY_ELISION_SELECTION @@ -365,38 +390,53 @@ TEST_F(AggregationHandleAvgDeathTest, WrongTypeTest) { double double_val = 0; float float_val = 0; - iterateHandle(aggregation_handle_avg_state_.get(), int_non_null_type.makeValue(&int_val)); + iterateHandle(aggregation_handle_avg_state_.get(), + int_non_null_type.makeValue(&int_val)); - EXPECT_DEATH(iterateHandle(aggregation_handle_avg_state_.get(), long_type.makeValue(&long_val)), ""); - EXPECT_DEATH(iterateHandle(aggregation_handle_avg_state_.get(), double_type.makeValue(&double_val)), ""); - EXPECT_DEATH(iterateHandle(aggregation_handle_avg_state_.get(), float_type.makeValue(&float_val)), ""); - EXPECT_DEATH(iterateHandle(aggregation_handle_avg_state_.get(), char_type.makeValue("asdf", 5)), ""); - EXPECT_DEATH(iterateHandle(aggregation_handle_avg_state_.get(), varchar_type.makeValue("asdf", 5)), ""); + EXPECT_DEATH(iterateHandle(aggregation_handle_avg_state_.get(), + long_type.makeValue(&long_val)), + ""); + EXPECT_DEATH(iterateHandle(aggregation_handle_avg_state_.get(), + double_type.makeValue(&double_val)), + ""); + EXPECT_DEATH(iterateHandle(aggregation_handle_avg_state_.get(), + float_type.makeValue(&float_val)), + ""); + EXPECT_DEATH(iterateHandle(aggregation_handle_avg_state_.get(), + char_type.makeValue("asdf", 5)), + ""); + EXPECT_DEATH(iterateHandle(aggregation_handle_avg_state_.get(), + varchar_type.makeValue("asdf", 5)), + ""); // Test mergeStates() with incorrectly typed handles. std::unique_ptr<AggregationHandle> aggregation_handle_avg_double( - AggregateFunctionFactory::Get(AggregationID::kAvg).createHandle( - std::vector<const Type*>(1, &double_type))); + AggregateFunctionFactory::Get(AggregationID::kAvg) + .createHandle(std::vector<const Type *>(1, &double_type))); std::unique_ptr<AggregationState> aggregation_state_avg_merge_double( aggregation_handle_avg_double->createInitialState()); - static_cast<const AggregationHandleAvg&>(*aggregation_handle_avg_double).iterateUnaryInl( - static_cast<AggregationStateAvg*>(aggregation_state_avg_merge_double.get()), - double_type.makeValue(&double_val)); - EXPECT_DEATH(aggregation_handle_avg_->mergeStates(*aggregation_state_avg_merge_double, - aggregation_handle_avg_state_.get()), - ""); + static_cast<const AggregationHandleAvg &>(*aggregation_handle_avg_double) + .iterateUnaryInl(static_cast<AggregationStateAvg *>( + aggregation_state_avg_merge_double.get()), + double_type.makeValue(&double_val)); + EXPECT_DEATH( + aggregation_handle_avg_->mergeStates(*aggregation_state_avg_merge_double, + aggregation_handle_avg_state_.get()), + ""); std::unique_ptr<AggregationHandle> aggregation_handle_avg_float( - AggregateFunctionFactory::Get(AggregationID::kAvg).createHandle( - std::vector<const Type*>(1, &float_type))); + AggregateFunctionFactory::Get(AggregationID::kAvg) + .createHandle(std::vector<const Type *>(1, &float_type))); std::unique_ptr<AggregationState> aggregation_state_avg_merge_float( aggregation_handle_avg_float->createInitialState()); - static_cast<const AggregationHandleAvg&>(*aggregation_handle_avg_float).iterateUnaryInl( - static_cast<AggregationStateAvg*>(aggregation_state_avg_merge_float.get()), - float_type.makeValue(&float_val)); - EXPECT_DEATH(aggregation_handle_avg_->mergeStates(*aggregation_state_avg_merge_float, - aggregation_handle_avg_state_.get()), - ""); + static_cast<const AggregationHandleAvg &>(*aggregation_handle_avg_float) + .iterateUnaryInl(static_cast<AggregationStateAvg *>( + aggregation_state_avg_merge_float.get()), + float_type.makeValue(&float_val)); + EXPECT_DEATH( + aggregation_handle_avg_->mergeStates(*aggregation_state_avg_merge_float, + aggregation_handle_avg_state_.get()), + ""); } #endif @@ -417,8 +457,10 @@ TEST_F(AggregationHandleAvgTest, ResultTypeForArgumentTypeTest) { EXPECT_TRUE(ResultTypeForArgumentTypeTest(kLong, kDouble)); EXPECT_TRUE(ResultTypeForArgumentTypeTest(kFloat, kDouble)); EXPECT_TRUE(ResultTypeForArgumentTypeTest(kDouble, kDouble)); - EXPECT_TRUE(ResultTypeForArgumentTypeTest(kDatetimeInterval, kDatetimeInterval)); - EXPECT_TRUE(ResultTypeForArgumentTypeTest(kYearMonthInterval, kYearMonthInterval)); + EXPECT_TRUE( + ResultTypeForArgumentTypeTest(kDatetimeInterval, kDatetimeInterval)); + EXPECT_TRUE( + ResultTypeForArgumentTypeTest(kYearMonthInterval, kYearMonthInterval)); } TEST_F(AggregationHandleAvgTest, GroupByTableMergeTestAvg) { @@ -426,25 +468,28 @@ TEST_F(AggregationHandleAvgTest, GroupByTableMergeTestAvg) { initializeHandle(long_non_null_type); storage_manager_.reset(new StorageManager("./test_avg_data")); std::unique_ptr<AggregationStateHashTableBase> source_hash_table( - aggregation_handle_avg_->createGroupByHashTable( - HashTableImplType::kSimpleScalarSeparateChaining, + AggregationStateFastHashTableFactory::CreateResizable( + HashTableImplType::kSeparateChaining, std::vector<const Type *>(1, &long_non_null_type), 10, + {aggregation_handle_avg_.get()->getPayloadSize()}, + {aggregation_handle_avg_.get()}, storage_manager_.get())); std::unique_ptr<AggregationStateHashTableBase> destination_hash_table( - aggregation_handle_avg_->createGroupByHashTable( - HashTableImplType::kSimpleScalarSeparateChaining, + AggregationStateFastHashTableFactory::CreateResizable( + HashTableImplType::kSeparateChaining, std::vector<const Type *>(1, &long_non_null_type), 10, + {aggregation_handle_avg_.get()->getPayloadSize()}, + {aggregation_handle_avg_.get()}, storage_manager_.get())); - AggregationStateHashTable<AggregationStateAvg> *destination_hash_table_derived = - static_cast<AggregationStateHashTable<AggregationStateAvg> *>( + AggregationStateFastHashTable *destination_hash_table_derived = + static_cast<AggregationStateFastHashTable *>( destination_hash_table.get()); - AggregationStateHashTable<AggregationStateAvg> *source_hash_table_derived = - static_cast<AggregationStateHashTable<AggregationStateAvg> *>( - source_hash_table.get()); + AggregationStateFastHashTable *source_hash_table_derived = + static_cast<AggregationStateFastHashTable *>(source_hash_table.get()); AggregationHandleAvg *aggregation_handle_avg_derived = static_cast<AggregationHandleAvg *>(aggregation_handle_avg_.get()); @@ -496,36 +541,56 @@ TEST_F(AggregationHandleAvgTest, GroupByTableMergeTestAvg) { exclusive_key_source_state.get(), exclusive_key_source_avg_val); // Add the key-state pairs to the hash tables. - source_hash_table_derived->putCompositeKey(common_key, - *common_key_source_state); - destination_hash_table_derived->putCompositeKey( - common_key, *common_key_destination_state); - source_hash_table_derived->putCompositeKey(exclusive_source_key, - *exclusive_key_source_state); - destination_hash_table_derived->putCompositeKey( - exclusive_destination_key, *exclusive_key_destination_state); + unsigned char buffer[100]; + buffer[0] = '\0'; + memcpy(buffer + 1, + common_key_source_state.get()->getPayloadAddress(), + aggregation_handle_avg_.get()->getPayloadSize()); + source_hash_table_derived->putCompositeKey(common_key, buffer); + + memcpy(buffer + 1, + common_key_destination_state.get()->getPayloadAddress(), + aggregation_handle_avg_.get()->getPayloadSize()); + destination_hash_table_derived->putCompositeKey(common_key, buffer); + + memcpy(buffer + 1, + exclusive_key_source_state.get()->getPayloadAddress(), + aggregation_handle_avg_.get()->getPayloadSize()); + source_hash_table_derived->putCompositeKey(exclusive_source_key, buffer); + + memcpy(buffer + 1, + exclusive_key_destination_state.get()->getPayloadAddress(), + aggregation_handle_avg_.get()->getPayloadSize()); + destination_hash_table_derived->putCompositeKey(exclusive_destination_key, + buffer); EXPECT_EQ(2u, destination_hash_table_derived->numEntries()); EXPECT_EQ(2u, source_hash_table_derived->numEntries()); - aggregation_handle_avg_->mergeGroupByHashTables(*source_hash_table, - destination_hash_table.get()); + AggregationOperationState::mergeGroupByHashTables( + source_hash_table.get(), destination_hash_table.get()); EXPECT_EQ(3u, destination_hash_table_derived->numEntries()); CheckAvgValue<double>( (common_key_destination_avg_val.getLiteral<std::int64_t>() + - common_key_source_avg_val.getLiteral<std::int64_t>()) / static_cast<double>(2), - *aggregation_handle_avg_derived, - *(destination_hash_table_derived->getSingleCompositeKey(common_key))); - CheckAvgValue<double>(exclusive_key_destination_avg_val.getLiteral<std::int64_t>(), - *aggregation_handle_avg_derived, - *(destination_hash_table_derived->getSingleCompositeKey( - exclusive_destination_key))); - CheckAvgValue<double>(exclusive_key_source_avg_val.getLiteral<std::int64_t>(), - *aggregation_handle_avg_derived, - *(source_hash_table_derived->getSingleCompositeKey( - exclusive_source_key))); + common_key_source_avg_val.getLiteral<std::int64_t>()) / + static_cast<double>(2), + aggregation_handle_avg_derived->finalizeHashTableEntryFast( + destination_hash_table_derived->getSingleCompositeKey(common_key) + + 1)); + CheckAvgValue<double>( + exclusive_key_destination_avg_val.getLiteral<std::int64_t>(), + aggregation_handle_avg_derived->finalizeHashTableEntryFast( + destination_hash_table_derived->getSingleCompositeKey( + exclusive_destination_key) + + 1)); + CheckAvgValue<double>( + exclusive_key_source_avg_val.getLiteral<std::int64_t>(), + aggregation_handle_avg_derived->finalizeHashTableEntryFast( + source_hash_table_derived->getSingleCompositeKey( + exclusive_source_key) + + 1)); } } // namespace quickstep http://git-wip-us.apache.org/repos/asf/incubator-quickstep/blob/c123bd49/expressions/aggregation/tests/AggregationHandleCount_unittest.cpp ---------------------------------------------------------------------- diff --git a/expressions/aggregation/tests/AggregationHandleCount_unittest.cpp b/expressions/aggregation/tests/AggregationHandleCount_unittest.cpp index 6565a41..78bd249 100644 --- a/expressions/aggregation/tests/AggregationHandleCount_unittest.cpp +++ b/expressions/aggregation/tests/AggregationHandleCount_unittest.cpp @@ -29,6 +29,8 @@ #include "expressions/aggregation/AggregationHandle.hpp" #include "expressions/aggregation/AggregationHandleCount.hpp" #include "expressions/aggregation/AggregationID.hpp" +#include "storage/AggregationOperationState.hpp" +#include "storage/FastHashTableFactory.hpp" #include "storage/StorageManager.hpp" #include "types/CharType.hpp" #include "types/DoubleType.hpp" @@ -50,85 +52,94 @@ namespace quickstep { -class AggregationHandleCountTest : public::testing::Test { +class AggregationHandleCountTest : public ::testing::Test { protected: const Type &dummy_type = TypeFactory::GetType(kInt); void iterateHandleNullary(AggregationState *state) { - static_cast<const AggregationHandleCount<true, false>&>( - *aggregation_handle_count_).iterateNullaryInl( - static_cast<AggregationStateCount*>(state)); + static_cast<const AggregationHandleCount<true, false> &>( + *aggregation_handle_count_) + .iterateNullaryInl(static_cast<AggregationStateCount *>(state)); } // Helper method that calls AggregationHandleCount::iterateUnaryInl() to // aggregate 'value' into '*state'. void iterateHandle(AggregationState *state, const TypedValue &value) { - static_cast<const AggregationHandleCount<false, true>&>( - *aggregation_handle_count_).iterateUnaryInl( - static_cast<AggregationStateCount*>(state), - value); + static_cast<const AggregationHandleCount<false, true> &>( + *aggregation_handle_count_) + .iterateUnaryInl(static_cast<AggregationStateCount *>(state), value); } void initializeHandle(const Type *argument_type) { if (argument_type == nullptr) { aggregation_handle_count_.reset( - AggregateFunctionFactory::Get(AggregationID::kCount).createHandle( - std::vector<const Type*>())); + AggregateFunctionFactory::Get(AggregationID::kCount) + .createHandle(std::vector<const Type *>())); } else { aggregation_handle_count_.reset( - AggregateFunctionFactory::Get(AggregationID::kCount).createHandle( - std::vector<const Type*>(1, argument_type))); + AggregateFunctionFactory::Get(AggregationID::kCount) + .createHandle(std::vector<const Type *>(1, argument_type))); } aggregation_handle_count_state_.reset( aggregation_handle_count_->createInitialState()); } static bool ApplyToTypesTest(TypeID typeID) { - const Type &type = (typeID == kChar || typeID == kVarChar) ? - TypeFactory::GetType(typeID, static_cast<std::size_t>(10)) : - TypeFactory::GetType(typeID); + const Type &type = + (typeID == kChar || typeID == kVarChar) + ? TypeFactory::GetType(typeID, static_cast<std::size_t>(10)) + : TypeFactory::GetType(typeID); - return AggregateFunctionFactory::Get(AggregationID::kCount).canApplyToTypes( - std::vector<const Type*>(1, &type)); + return AggregateFunctionFactory::Get(AggregationID::kCount) + .canApplyToTypes(std::vector<const Type *>(1, &type)); } static bool ResultTypeForArgumentTypeTest(TypeID input_type_id, TypeID output_type_id) { - const Type *result_type - = AggregateFunctionFactory::Get(AggregationID::kCount).resultTypeForArgumentTypes( - std::vector<const Type*>(1, &TypeFactory::GetType(input_type_id))); + const Type *result_type = + AggregateFunctionFactory::Get(AggregationID::kCount) + .resultTypeForArgumentTypes(std::vector<const Type *>( + 1, &TypeFactory::GetType(input_type_id))); return (result_type->getTypeID() == output_type_id); } - static void CheckCountValue( - std::int64_t expected, - const AggregationHandle &handle, - const AggregationState &state) { + static void CheckCountValue(std::int64_t expected, + const AggregationHandle &handle, + const AggregationState &state) { EXPECT_EQ(expected, handle.finalize(state).getLiteral<std::int64_t>()); } + static void CheckCountValue(std::int64_t expected, const TypedValue &value) { + EXPECT_EQ(expected, value.getLiteral<std::int64_t>()); + } + void checkAggregationCountNullary(int test_count) { initializeHandle(nullptr); - CheckCountValue(0, *aggregation_handle_count_, *aggregation_handle_count_state_); + CheckCountValue( + 0, *aggregation_handle_count_, *aggregation_handle_count_state_); for (int i = 0; i < test_count; ++i) { iterateHandleNullary(aggregation_handle_count_state_.get()); } - CheckCountValue(test_count, *aggregation_handle_count_, *aggregation_handle_count_state_); + CheckCountValue(test_count, + *aggregation_handle_count_, + *aggregation_handle_count_state_); // Test mergeStates. std::unique_ptr<AggregationState> merge_state( aggregation_handle_count_->createInitialState()); - aggregation_handle_count_->mergeStates(*merge_state, - aggregation_handle_count_state_.get()); + aggregation_handle_count_->mergeStates( + *merge_state, aggregation_handle_count_state_.get()); for (int i = 0; i < test_count; ++i) { iterateHandleNullary(merge_state.get()); } - aggregation_handle_count_->mergeStates(*merge_state, - aggregation_handle_count_state_.get()); - CheckCountValue(2 * test_count, *aggregation_handle_count_, *aggregation_handle_count_state_); + aggregation_handle_count_->mergeStates( + *merge_state, aggregation_handle_count_state_.get()); + CheckCountValue(2 * test_count, + *aggregation_handle_count_, + *aggregation_handle_count_state_); } void checkAggregationCountNullaryAccumulate(int test_count) { @@ -139,12 +150,10 @@ class AggregationHandleCountTest : public::testing::Test { // Test the state generated directly by accumulateNullary(), and also test // after merging back. - CheckCountValue(test_count, - *aggregation_handle_count_, - *accumulated_state); + CheckCountValue(test_count, *aggregation_handle_count_, *accumulated_state); - aggregation_handle_count_->mergeStates(*accumulated_state, - aggregation_handle_count_state_.get()); + aggregation_handle_count_->mergeStates( + *accumulated_state, aggregation_handle_count_state_.get()); CheckCountValue(test_count, *aggregation_handle_count_, *aggregation_handle_count_state_); @@ -154,24 +163,27 @@ class AggregationHandleCountTest : public::testing::Test { void checkAggregationCountNumeric(int test_count) { const NumericType &type = NumericType::Instance(true); initializeHandle(&type); - CheckCountValue(0, *aggregation_handle_count_, *aggregation_handle_count_state_); + CheckCountValue( + 0, *aggregation_handle_count_, *aggregation_handle_count_state_); typename NumericType::cpptype val = 0; int count = 0; iterateHandle(aggregation_handle_count_state_.get(), type.makeNullValue()); for (int i = 0; i < test_count; ++i) { - iterateHandle(aggregation_handle_count_state_.get(), type.makeValue(&val)); + iterateHandle(aggregation_handle_count_state_.get(), + type.makeValue(&val)); ++count; } iterateHandle(aggregation_handle_count_state_.get(), type.makeNullValue()); - CheckCountValue(count, *aggregation_handle_count_, *aggregation_handle_count_state_); + CheckCountValue( + count, *aggregation_handle_count_, *aggregation_handle_count_state_); // Test mergeStates. std::unique_ptr<AggregationState> merge_state( aggregation_handle_count_->createInitialState()); - aggregation_handle_count_->mergeStates(*merge_state, - aggregation_handle_count_state_.get()); + aggregation_handle_count_->mergeStates( + *merge_state, aggregation_handle_count_state_.get()); iterateHandle(merge_state.get(), type.makeNullValue()); for (int i = 0; i < test_count; ++i) { @@ -180,13 +192,14 @@ class AggregationHandleCountTest : public::testing::Test { } iterateHandle(merge_state.get(), type.makeNullValue()); - aggregation_handle_count_->mergeStates(*merge_state, - aggregation_handle_count_state_.get()); - CheckCountValue(count, *aggregation_handle_count_, *aggregation_handle_count_state_); + aggregation_handle_count_->mergeStates( + *merge_state, aggregation_handle_count_state_.get()); + CheckCountValue( + count, *aggregation_handle_count_, *aggregation_handle_count_state_); } template <typename NumericType> - ColumnVector *createColumnVectorNumeric(const Type &type, int test_count) { + ColumnVector* createColumnVectorNumeric(const Type &type, int test_count) { NativeColumnVector *column = new NativeColumnVector(type, test_count + 3); typename NumericType::cpptype val = 0; @@ -194,7 +207,7 @@ class AggregationHandleCountTest : public::testing::Test { for (int i = 0; i < test_count; ++i) { column->appendTypedValue(type.makeValue(&val)); // One NULL in the middle. - if (i == test_count/2) { + if (i == test_count / 2) { column->appendTypedValue(type.makeNullValue()); } } @@ -206,21 +219,22 @@ class AggregationHandleCountTest : public::testing::Test { void checkAggregationCountNumericColumnVector(int test_count) { const NumericType &type = NumericType::Instance(true); initializeHandle(&type); - CheckCountValue(0, *aggregation_handle_count_, *aggregation_handle_count_state_); + CheckCountValue( + 0, *aggregation_handle_count_, *aggregation_handle_count_state_); std::vector<std::unique_ptr<ColumnVector>> column_vectors; - column_vectors.emplace_back(createColumnVectorNumeric<NumericType>(type, test_count)); + column_vectors.emplace_back( + createColumnVectorNumeric<NumericType>(type, test_count)); std::unique_ptr<AggregationState> cv_state( aggregation_handle_count_->accumulateColumnVectors(column_vectors)); // Test the state generated directly by accumulateColumnVectors(), and also // test after merging back. - CheckCountValue(test_count, - *aggregation_handle_count_, - *cv_state); + CheckCountValue(test_count, *aggregation_handle_count_, *cv_state); - aggregation_handle_count_->mergeStates(*cv_state, aggregation_handle_count_state_.get()); + aggregation_handle_count_->mergeStates( + *cv_state, aggregation_handle_count_state_.get()); CheckCountValue(test_count, *aggregation_handle_count_, *aggregation_handle_count_state_); @@ -231,22 +245,24 @@ class AggregationHandleCountTest : public::testing::Test { void checkAggregationCountNumericValueAccessor(int test_count) { const NumericType &type = NumericType::Instance(true); initializeHandle(&type); - CheckCountValue(0, *aggregation_handle_count_, *aggregation_handle_count_state_); + CheckCountValue( + 0, *aggregation_handle_count_, *aggregation_handle_count_state_); - std::unique_ptr<ColumnVectorsValueAccessor> accessor(new ColumnVectorsValueAccessor()); - accessor->addColumn(createColumnVectorNumeric<NumericType>(type, test_count)); + std::unique_ptr<ColumnVectorsValueAccessor> accessor( + new ColumnVectorsValueAccessor()); + accessor->addColumn( + createColumnVectorNumeric<NumericType>(type, test_count)); std::unique_ptr<AggregationState> va_state( - aggregation_handle_count_->accumulateValueAccessor(accessor.get(), - std::vector<attribute_id>(1, 0))); + aggregation_handle_count_->accumulateValueAccessor( + accessor.get(), std::vector<attribute_id>(1, 0))); // Test the state generated directly by accumulateValueAccessor(), and also // test after merging back. - CheckCountValue(test_count, - *aggregation_handle_count_, - *va_state); + CheckCountValue(test_count, *aggregation_handle_count_, *va_state); - aggregation_handle_count_->mergeStates(*va_state, aggregation_handle_count_state_.get()); + aggregation_handle_count_->mergeStates( + *va_state, aggregation_handle_count_state_.get()); CheckCountValue(test_count, *aggregation_handle_count_, *aggregation_handle_count_state_); @@ -257,7 +273,8 @@ class AggregationHandleCountTest : public::testing::Test { void checkAggregationCountString(int test_count) { const StringType &type = StringType::Instance(10, true); initializeHandle(&type); - CheckCountValue(0, *aggregation_handle_count_, *aggregation_handle_count_state_); + CheckCountValue( + 0, *aggregation_handle_count_, *aggregation_handle_count_state_); std::string string_literal = "test_str"; int count = 0; @@ -269,7 +286,8 @@ class AggregationHandleCountTest : public::testing::Test { ++count; } iterateHandle(aggregation_handle_count_state_.get(), type.makeNullValue()); - CheckCountValue(count, *aggregation_handle_count_, *aggregation_handle_count_state_); + CheckCountValue( + count, *aggregation_handle_count_, *aggregation_handle_count_state_); // Test mergeStates(). std::unique_ptr<AggregationState> merge_state( @@ -277,18 +295,20 @@ class AggregationHandleCountTest : public::testing::Test { iterateHandle(merge_state.get(), type.makeNullValue()); for (int i = 0; i < test_count; ++i) { - iterateHandle(merge_state.get(), type.makeValue(string_literal.c_str(), 10)); + iterateHandle(merge_state.get(), + type.makeValue(string_literal.c_str(), 10)); ++count; } iterateHandle(merge_state.get(), type.makeNullValue()); - aggregation_handle_count_->mergeStates(*merge_state, - aggregation_handle_count_state_.get()); - CheckCountValue(count, *aggregation_handle_count_, *aggregation_handle_count_state_); + aggregation_handle_count_->mergeStates( + *merge_state, aggregation_handle_count_state_.get()); + CheckCountValue( + count, *aggregation_handle_count_, *aggregation_handle_count_state_); } template <typename ColumnVectorType> - ColumnVector *createColumnVectorString(const Type &type, int test_count) { + ColumnVector* createColumnVectorString(const Type &type, int test_count) { ColumnVectorType *column = new ColumnVectorType(type, test_count + 3); std::string string_literal = "test_str"; @@ -296,7 +316,7 @@ class AggregationHandleCountTest : public::testing::Test { for (int i = 0; i < test_count; ++i) { column->appendTypedValue(type.makeValue(string_literal.c_str(), 10)); // One NULL in the middle. - if (i == test_count/2) { + if (i == test_count / 2) { column->appendTypedValue(type.makeNullValue()); } } @@ -309,21 +329,22 @@ class AggregationHandleCountTest : public::testing::Test { void checkAggregationCountStringColumnVector(int test_count) { const StringType &type = StringType::Instance(10, true); initializeHandle(&type); - CheckCountValue(0, *aggregation_handle_count_, *aggregation_handle_count_state_); + CheckCountValue( + 0, *aggregation_handle_count_, *aggregation_handle_count_state_); std::vector<std::unique_ptr<ColumnVector>> column_vectors; - column_vectors.emplace_back(createColumnVectorString<ColumnVectorType>(type, test_count)); + column_vectors.emplace_back( + createColumnVectorString<ColumnVectorType>(type, test_count)); std::unique_ptr<AggregationState> cv_state( aggregation_handle_count_->accumulateColumnVectors(column_vectors)); // Test the state generated directly by accumulateColumnVectors(), and also // test after merging back. - CheckCountValue(test_count, - *aggregation_handle_count_, - *cv_state); + CheckCountValue(test_count, *aggregation_handle_count_, *cv_state); - aggregation_handle_count_->mergeStates(*cv_state, aggregation_handle_count_state_.get()); + aggregation_handle_count_->mergeStates( + *cv_state, aggregation_handle_count_state_.get()); CheckCountValue(test_count, *aggregation_handle_count_, *aggregation_handle_count_state_); @@ -334,22 +355,24 @@ class AggregationHandleCountTest : public::testing::Test { void checkAggregationCountStringValueAccessor(int test_count) { const StringType &type = StringType::Instance(10, true); initializeHandle(&type); - CheckCountValue(0, *aggregation_handle_count_, *aggregation_handle_count_state_); + CheckCountValue( + 0, *aggregation_handle_count_, *aggregation_handle_count_state_); - std::unique_ptr<ColumnVectorsValueAccessor> accessor(new ColumnVectorsValueAccessor()); - accessor->addColumn(createColumnVectorString<ColumnVectorType>(type, test_count)); + std::unique_ptr<ColumnVectorsValueAccessor> accessor( + new ColumnVectorsValueAccessor()); + accessor->addColumn( + createColumnVectorString<ColumnVectorType>(type, test_count)); std::unique_ptr<AggregationState> va_state( - aggregation_handle_count_->accumulateValueAccessor(accessor.get(), - std::vector<attribute_id>(1, 0))); + aggregation_handle_count_->accumulateValueAccessor( + accessor.get(), std::vector<attribute_id>(1, 0))); // Test the state generated directly by accumulateValueAccessor(), and also // test after merging back. - CheckCountValue(test_count, - *aggregation_handle_count_, - *va_state); + CheckCountValue(test_count, *aggregation_handle_count_, *va_state); - aggregation_handle_count_->mergeStates(*va_state, aggregation_handle_count_state_.get()); + aggregation_handle_count_->mergeStates( + *va_state, aggregation_handle_count_state_.get()); CheckCountValue(test_count, *aggregation_handle_count_, *aggregation_handle_count_state_); @@ -364,13 +387,12 @@ class AggregationHandleCountTest : public::testing::Test { typedef AggregationHandleCountTest AggregationHandleCountDeathTest; TEST_F(AggregationHandleCountTest, CountStarTest) { - checkAggregationCountNullary(0), - checkAggregationCountNullary(10000); + checkAggregationCountNullary(0), checkAggregationCountNullary(10000); } TEST_F(AggregationHandleCountTest, CountStarAccumulateTest) { checkAggregationCountNullaryAccumulate(0), - checkAggregationCountNullaryAccumulate(10000); + checkAggregationCountNullaryAccumulate(10000); } TEST_F(AggregationHandleCountTest, IntTypeTest) { @@ -430,7 +452,8 @@ TEST_F(AggregationHandleCountTest, CharTypeColumnVectorTest) { TEST_F(AggregationHandleCountTest, VarCharTypeColumnVectorTest) { checkAggregationCountStringColumnVector<VarCharType, IndirectColumnVector>(0); - checkAggregationCountStringColumnVector<VarCharType, IndirectColumnVector>(10000); + checkAggregationCountStringColumnVector<VarCharType, IndirectColumnVector>( + 10000); } #ifdef QUICKSTEP_ENABLE_VECTOR_COPY_ELISION_SELECTION @@ -460,8 +483,10 @@ TEST_F(AggregationHandleCountTest, CharTypeValueAccessorTest) { } TEST_F(AggregationHandleCountTest, VarCharTypeValueAccessorTest) { - checkAggregationCountStringValueAccessor<VarCharType, IndirectColumnVector>(0); - checkAggregationCountStringValueAccessor<VarCharType, IndirectColumnVector>(10000); + checkAggregationCountStringValueAccessor<VarCharType, IndirectColumnVector>( + 0); + checkAggregationCountStringValueAccessor<VarCharType, IndirectColumnVector>( + 10000); } #endif // QUICKSTEP_ENABLE_VECTOR_COPY_ELISION_SELECTION @@ -486,25 +511,28 @@ TEST_F(AggregationHandleCountTest, GroupByTableMergeTestCount) { initializeHandle(&long_non_null_type); storage_manager_.reset(new StorageManager("./test_count_data")); std::unique_ptr<AggregationStateHashTableBase> source_hash_table( - aggregation_handle_count_->createGroupByHashTable( - HashTableImplType::kSimpleScalarSeparateChaining, + AggregationStateFastHashTableFactory::CreateResizable( + HashTableImplType::kSeparateChaining, std::vector<const Type *>(1, &long_non_null_type), 10, + {aggregation_handle_count_.get()->getPayloadSize()}, + {aggregation_handle_count_.get()}, storage_manager_.get())); std::unique_ptr<AggregationStateHashTableBase> destination_hash_table( - aggregation_handle_count_->createGroupByHashTable( - HashTableImplType::kSimpleScalarSeparateChaining, + AggregationStateFastHashTableFactory::CreateResizable( + HashTableImplType::kSeparateChaining, std::vector<const Type *>(1, &long_non_null_type), 10, + {aggregation_handle_count_.get()->getPayloadSize()}, + {aggregation_handle_count_.get()}, storage_manager_.get())); - AggregationStateHashTable<AggregationStateCount> *destination_hash_table_derived = - static_cast<AggregationStateHashTable<AggregationStateCount> *>( + AggregationStateFastHashTable *destination_hash_table_derived = + static_cast<AggregationStateFastHashTable *>( destination_hash_table.get()); - AggregationStateHashTable<AggregationStateCount> *source_hash_table_derived = - static_cast<AggregationStateHashTable<AggregationStateCount> *>( - source_hash_table.get()); + AggregationStateFastHashTable *source_hash_table_derived = + static_cast<AggregationStateFastHashTable *>(source_hash_table.get()); // TODO(harshad) - Use TemplateUtil::CreateBoolInstantiatedInstance to // generate all the combinations of the bool template arguments and test them. @@ -530,7 +558,8 @@ TEST_F(AggregationHandleCountTest, GroupByTableMergeTestCount) { TypedValue exclusive_key_source_count_val(exclusive_key_source_count); const std::int64_t exclusive_key_destination_count = 1; - TypedValue exclusive_key_destination_count_val(exclusive_key_destination_count); + TypedValue exclusive_key_destination_count_val( + exclusive_key_destination_count); std::unique_ptr<AggregationStateCount> common_key_source_state( static_cast<AggregationStateCount *>( @@ -546,62 +575,86 @@ TEST_F(AggregationHandleCountTest, GroupByTableMergeTestCount) { aggregation_handle_count_->createInitialState())); // Create count value states for keys. - aggregation_handle_count_derived->iterateUnaryInl(common_key_source_state.get(), - common_key_source_count_val); - std::int64_t actual_val = aggregation_handle_count_->finalize(*common_key_source_state) - .getLiteral<std::int64_t>(); + aggregation_handle_count_derived->iterateUnaryInl( + common_key_source_state.get(), common_key_source_count_val); + std::int64_t actual_val = + aggregation_handle_count_->finalize(*common_key_source_state) + .getLiteral<std::int64_t>(); EXPECT_EQ(common_key_source_count_val.getLiteral<std::int64_t>(), actual_val); aggregation_handle_count_derived->iterateUnaryInl( common_key_destination_state.get(), common_key_destination_count_val); - actual_val = aggregation_handle_count_->finalize(*common_key_destination_state) - .getLiteral<std::int64_t>(); - EXPECT_EQ(common_key_destination_count_val.getLiteral<std::int64_t>(), actual_val); + actual_val = + aggregation_handle_count_->finalize(*common_key_destination_state) + .getLiteral<std::int64_t>(); + EXPECT_EQ(common_key_destination_count_val.getLiteral<std::int64_t>(), + actual_val); aggregation_handle_count_derived->iterateUnaryInl( - exclusive_key_destination_state.get(), exclusive_key_destination_count_val); + exclusive_key_destination_state.get(), + exclusive_key_destination_count_val); actual_val = aggregation_handle_count_->finalize(*exclusive_key_destination_state) .getLiteral<std::int64_t>(); - EXPECT_EQ(exclusive_key_destination_count_val.getLiteral<std::int64_t>(), actual_val); + EXPECT_EQ(exclusive_key_destination_count_val.getLiteral<std::int64_t>(), + actual_val); aggregation_handle_count_derived->iterateUnaryInl( exclusive_key_source_state.get(), exclusive_key_source_count_val); actual_val = aggregation_handle_count_->finalize(*exclusive_key_source_state) .getLiteral<std::int64_t>(); - EXPECT_EQ(exclusive_key_source_count_val.getLiteral<std::int64_t>(), actual_val); + EXPECT_EQ(exclusive_key_source_count_val.getLiteral<std::int64_t>(), + actual_val); // Add the key-state pairs to the hash tables. - source_hash_table_derived->putCompositeKey(common_key, - *common_key_source_state); - destination_hash_table_derived->putCompositeKey( - common_key, *common_key_destination_state); - source_hash_table_derived->putCompositeKey(exclusive_source_key, - *exclusive_key_source_state); - destination_hash_table_derived->putCompositeKey( - exclusive_destination_key, *exclusive_key_destination_state); + unsigned char buffer[100]; + buffer[0] = '\0'; + memcpy(buffer + 1, + common_key_source_state.get()->getPayloadAddress(), + aggregation_handle_count_.get()->getPayloadSize()); + source_hash_table_derived->putCompositeKey(common_key, buffer); + + memcpy(buffer + 1, + common_key_destination_state.get()->getPayloadAddress(), + aggregation_handle_count_.get()->getPayloadSize()); + destination_hash_table_derived->putCompositeKey(common_key, buffer); + + memcpy(buffer + 1, + exclusive_key_source_state.get()->getPayloadAddress(), + aggregation_handle_count_.get()->getPayloadSize()); + source_hash_table_derived->putCompositeKey(exclusive_source_key, buffer); + + memcpy(buffer + 1, + exclusive_key_destination_state.get()->getPayloadAddress(), + aggregation_handle_count_.get()->getPayloadSize()); + destination_hash_table_derived->putCompositeKey(exclusive_destination_key, + buffer); EXPECT_EQ(2u, destination_hash_table_derived->numEntries()); EXPECT_EQ(2u, source_hash_table_derived->numEntries()); - aggregation_handle_count_->mergeGroupByHashTables(*source_hash_table, - destination_hash_table.get()); + AggregationOperationState::mergeGroupByHashTables( + source_hash_table.get(), destination_hash_table.get()); EXPECT_EQ(3u, destination_hash_table_derived->numEntries()); CheckCountValue( common_key_destination_count_val.getLiteral<std::int64_t>() + common_key_source_count_val.getLiteral<std::int64_t>(), - *aggregation_handle_count_derived, - *(destination_hash_table_derived->getSingleCompositeKey(common_key))); - CheckCountValue(exclusive_key_destination_count_val.getLiteral<std::int64_t>(), - *aggregation_handle_count_derived, - *(destination_hash_table_derived->getSingleCompositeKey( - exclusive_destination_key))); + aggregation_handle_count_derived->finalizeHashTableEntryFast( + destination_hash_table_derived->getSingleCompositeKey(common_key) + + 1)); + CheckCountValue( + exclusive_key_destination_count_val.getLiteral<std::int64_t>(), + aggregation_handle_count_derived->finalizeHashTableEntryFast( + destination_hash_table_derived->getSingleCompositeKey( + exclusive_destination_key) + + 1)); CheckCountValue(exclusive_key_source_count_val.getLiteral<std::int64_t>(), - *aggregation_handle_count_derived, - *(source_hash_table_derived->getSingleCompositeKey( - exclusive_source_key))); + aggregation_handle_count_derived->finalizeHashTableEntryFast( + source_hash_table_derived->getSingleCompositeKey( + exclusive_source_key) + + 1)); } } // namespace quickstep