http://git-wip-us.apache.org/repos/asf/incubator-quickstep/blob/ac3512ce/expressions/aggregation/tests/AggregationHandleCount_unittest.cpp ---------------------------------------------------------------------- diff --git a/expressions/aggregation/tests/AggregationHandleCount_unittest.cpp b/expressions/aggregation/tests/AggregationHandleCount_unittest.cpp index 6565a41..78bd249 100644 --- a/expressions/aggregation/tests/AggregationHandleCount_unittest.cpp +++ b/expressions/aggregation/tests/AggregationHandleCount_unittest.cpp @@ -29,6 +29,8 @@ #include "expressions/aggregation/AggregationHandle.hpp" #include "expressions/aggregation/AggregationHandleCount.hpp" #include "expressions/aggregation/AggregationID.hpp" +#include "storage/AggregationOperationState.hpp" +#include "storage/FastHashTableFactory.hpp" #include "storage/StorageManager.hpp" #include "types/CharType.hpp" #include "types/DoubleType.hpp" @@ -50,85 +52,94 @@ namespace quickstep { -class AggregationHandleCountTest : public::testing::Test { +class AggregationHandleCountTest : public ::testing::Test { protected: const Type &dummy_type = TypeFactory::GetType(kInt); void iterateHandleNullary(AggregationState *state) { - static_cast<const AggregationHandleCount<true, false>&>( - *aggregation_handle_count_).iterateNullaryInl( - static_cast<AggregationStateCount*>(state)); + static_cast<const AggregationHandleCount<true, false> &>( + *aggregation_handle_count_) + .iterateNullaryInl(static_cast<AggregationStateCount *>(state)); } // Helper method that calls AggregationHandleCount::iterateUnaryInl() to // aggregate 'value' into '*state'. void iterateHandle(AggregationState *state, const TypedValue &value) { - static_cast<const AggregationHandleCount<false, true>&>( - *aggregation_handle_count_).iterateUnaryInl( - static_cast<AggregationStateCount*>(state), - value); + static_cast<const AggregationHandleCount<false, true> &>( + *aggregation_handle_count_) + .iterateUnaryInl(static_cast<AggregationStateCount *>(state), value); } void initializeHandle(const Type *argument_type) { if (argument_type == nullptr) { aggregation_handle_count_.reset( - AggregateFunctionFactory::Get(AggregationID::kCount).createHandle( - std::vector<const Type*>())); + AggregateFunctionFactory::Get(AggregationID::kCount) + .createHandle(std::vector<const Type *>())); } else { aggregation_handle_count_.reset( - AggregateFunctionFactory::Get(AggregationID::kCount).createHandle( - std::vector<const Type*>(1, argument_type))); + AggregateFunctionFactory::Get(AggregationID::kCount) + .createHandle(std::vector<const Type *>(1, argument_type))); } aggregation_handle_count_state_.reset( aggregation_handle_count_->createInitialState()); } static bool ApplyToTypesTest(TypeID typeID) { - const Type &type = (typeID == kChar || typeID == kVarChar) ? - TypeFactory::GetType(typeID, static_cast<std::size_t>(10)) : - TypeFactory::GetType(typeID); + const Type &type = + (typeID == kChar || typeID == kVarChar) + ? TypeFactory::GetType(typeID, static_cast<std::size_t>(10)) + : TypeFactory::GetType(typeID); - return AggregateFunctionFactory::Get(AggregationID::kCount).canApplyToTypes( - std::vector<const Type*>(1, &type)); + return AggregateFunctionFactory::Get(AggregationID::kCount) + .canApplyToTypes(std::vector<const Type *>(1, &type)); } static bool ResultTypeForArgumentTypeTest(TypeID input_type_id, TypeID output_type_id) { - const Type *result_type - = AggregateFunctionFactory::Get(AggregationID::kCount).resultTypeForArgumentTypes( - std::vector<const Type*>(1, &TypeFactory::GetType(input_type_id))); + const Type *result_type = + AggregateFunctionFactory::Get(AggregationID::kCount) + .resultTypeForArgumentTypes(std::vector<const Type *>( + 1, &TypeFactory::GetType(input_type_id))); return (result_type->getTypeID() == output_type_id); } - static void CheckCountValue( - std::int64_t expected, - const AggregationHandle &handle, - const AggregationState &state) { + static void CheckCountValue(std::int64_t expected, + const AggregationHandle &handle, + const AggregationState &state) { EXPECT_EQ(expected, handle.finalize(state).getLiteral<std::int64_t>()); } + static void CheckCountValue(std::int64_t expected, const TypedValue &value) { + EXPECT_EQ(expected, value.getLiteral<std::int64_t>()); + } + void checkAggregationCountNullary(int test_count) { initializeHandle(nullptr); - CheckCountValue(0, *aggregation_handle_count_, *aggregation_handle_count_state_); + CheckCountValue( + 0, *aggregation_handle_count_, *aggregation_handle_count_state_); for (int i = 0; i < test_count; ++i) { iterateHandleNullary(aggregation_handle_count_state_.get()); } - CheckCountValue(test_count, *aggregation_handle_count_, *aggregation_handle_count_state_); + CheckCountValue(test_count, + *aggregation_handle_count_, + *aggregation_handle_count_state_); // Test mergeStates. std::unique_ptr<AggregationState> merge_state( aggregation_handle_count_->createInitialState()); - aggregation_handle_count_->mergeStates(*merge_state, - aggregation_handle_count_state_.get()); + aggregation_handle_count_->mergeStates( + *merge_state, aggregation_handle_count_state_.get()); for (int i = 0; i < test_count; ++i) { iterateHandleNullary(merge_state.get()); } - aggregation_handle_count_->mergeStates(*merge_state, - aggregation_handle_count_state_.get()); - CheckCountValue(2 * test_count, *aggregation_handle_count_, *aggregation_handle_count_state_); + aggregation_handle_count_->mergeStates( + *merge_state, aggregation_handle_count_state_.get()); + CheckCountValue(2 * test_count, + *aggregation_handle_count_, + *aggregation_handle_count_state_); } void checkAggregationCountNullaryAccumulate(int test_count) { @@ -139,12 +150,10 @@ class AggregationHandleCountTest : public::testing::Test { // Test the state generated directly by accumulateNullary(), and also test // after merging back. - CheckCountValue(test_count, - *aggregation_handle_count_, - *accumulated_state); + CheckCountValue(test_count, *aggregation_handle_count_, *accumulated_state); - aggregation_handle_count_->mergeStates(*accumulated_state, - aggregation_handle_count_state_.get()); + aggregation_handle_count_->mergeStates( + *accumulated_state, aggregation_handle_count_state_.get()); CheckCountValue(test_count, *aggregation_handle_count_, *aggregation_handle_count_state_); @@ -154,24 +163,27 @@ class AggregationHandleCountTest : public::testing::Test { void checkAggregationCountNumeric(int test_count) { const NumericType &type = NumericType::Instance(true); initializeHandle(&type); - CheckCountValue(0, *aggregation_handle_count_, *aggregation_handle_count_state_); + CheckCountValue( + 0, *aggregation_handle_count_, *aggregation_handle_count_state_); typename NumericType::cpptype val = 0; int count = 0; iterateHandle(aggregation_handle_count_state_.get(), type.makeNullValue()); for (int i = 0; i < test_count; ++i) { - iterateHandle(aggregation_handle_count_state_.get(), type.makeValue(&val)); + iterateHandle(aggregation_handle_count_state_.get(), + type.makeValue(&val)); ++count; } iterateHandle(aggregation_handle_count_state_.get(), type.makeNullValue()); - CheckCountValue(count, *aggregation_handle_count_, *aggregation_handle_count_state_); + CheckCountValue( + count, *aggregation_handle_count_, *aggregation_handle_count_state_); // Test mergeStates. std::unique_ptr<AggregationState> merge_state( aggregation_handle_count_->createInitialState()); - aggregation_handle_count_->mergeStates(*merge_state, - aggregation_handle_count_state_.get()); + aggregation_handle_count_->mergeStates( + *merge_state, aggregation_handle_count_state_.get()); iterateHandle(merge_state.get(), type.makeNullValue()); for (int i = 0; i < test_count; ++i) { @@ -180,13 +192,14 @@ class AggregationHandleCountTest : public::testing::Test { } iterateHandle(merge_state.get(), type.makeNullValue()); - aggregation_handle_count_->mergeStates(*merge_state, - aggregation_handle_count_state_.get()); - CheckCountValue(count, *aggregation_handle_count_, *aggregation_handle_count_state_); + aggregation_handle_count_->mergeStates( + *merge_state, aggregation_handle_count_state_.get()); + CheckCountValue( + count, *aggregation_handle_count_, *aggregation_handle_count_state_); } template <typename NumericType> - ColumnVector *createColumnVectorNumeric(const Type &type, int test_count) { + ColumnVector* createColumnVectorNumeric(const Type &type, int test_count) { NativeColumnVector *column = new NativeColumnVector(type, test_count + 3); typename NumericType::cpptype val = 0; @@ -194,7 +207,7 @@ class AggregationHandleCountTest : public::testing::Test { for (int i = 0; i < test_count; ++i) { column->appendTypedValue(type.makeValue(&val)); // One NULL in the middle. - if (i == test_count/2) { + if (i == test_count / 2) { column->appendTypedValue(type.makeNullValue()); } } @@ -206,21 +219,22 @@ class AggregationHandleCountTest : public::testing::Test { void checkAggregationCountNumericColumnVector(int test_count) { const NumericType &type = NumericType::Instance(true); initializeHandle(&type); - CheckCountValue(0, *aggregation_handle_count_, *aggregation_handle_count_state_); + CheckCountValue( + 0, *aggregation_handle_count_, *aggregation_handle_count_state_); std::vector<std::unique_ptr<ColumnVector>> column_vectors; - column_vectors.emplace_back(createColumnVectorNumeric<NumericType>(type, test_count)); + column_vectors.emplace_back( + createColumnVectorNumeric<NumericType>(type, test_count)); std::unique_ptr<AggregationState> cv_state( aggregation_handle_count_->accumulateColumnVectors(column_vectors)); // Test the state generated directly by accumulateColumnVectors(), and also // test after merging back. - CheckCountValue(test_count, - *aggregation_handle_count_, - *cv_state); + CheckCountValue(test_count, *aggregation_handle_count_, *cv_state); - aggregation_handle_count_->mergeStates(*cv_state, aggregation_handle_count_state_.get()); + aggregation_handle_count_->mergeStates( + *cv_state, aggregation_handle_count_state_.get()); CheckCountValue(test_count, *aggregation_handle_count_, *aggregation_handle_count_state_); @@ -231,22 +245,24 @@ class AggregationHandleCountTest : public::testing::Test { void checkAggregationCountNumericValueAccessor(int test_count) { const NumericType &type = NumericType::Instance(true); initializeHandle(&type); - CheckCountValue(0, *aggregation_handle_count_, *aggregation_handle_count_state_); + CheckCountValue( + 0, *aggregation_handle_count_, *aggregation_handle_count_state_); - std::unique_ptr<ColumnVectorsValueAccessor> accessor(new ColumnVectorsValueAccessor()); - accessor->addColumn(createColumnVectorNumeric<NumericType>(type, test_count)); + std::unique_ptr<ColumnVectorsValueAccessor> accessor( + new ColumnVectorsValueAccessor()); + accessor->addColumn( + createColumnVectorNumeric<NumericType>(type, test_count)); std::unique_ptr<AggregationState> va_state( - aggregation_handle_count_->accumulateValueAccessor(accessor.get(), - std::vector<attribute_id>(1, 0))); + aggregation_handle_count_->accumulateValueAccessor( + accessor.get(), std::vector<attribute_id>(1, 0))); // Test the state generated directly by accumulateValueAccessor(), and also // test after merging back. - CheckCountValue(test_count, - *aggregation_handle_count_, - *va_state); + CheckCountValue(test_count, *aggregation_handle_count_, *va_state); - aggregation_handle_count_->mergeStates(*va_state, aggregation_handle_count_state_.get()); + aggregation_handle_count_->mergeStates( + *va_state, aggregation_handle_count_state_.get()); CheckCountValue(test_count, *aggregation_handle_count_, *aggregation_handle_count_state_); @@ -257,7 +273,8 @@ class AggregationHandleCountTest : public::testing::Test { void checkAggregationCountString(int test_count) { const StringType &type = StringType::Instance(10, true); initializeHandle(&type); - CheckCountValue(0, *aggregation_handle_count_, *aggregation_handle_count_state_); + CheckCountValue( + 0, *aggregation_handle_count_, *aggregation_handle_count_state_); std::string string_literal = "test_str"; int count = 0; @@ -269,7 +286,8 @@ class AggregationHandleCountTest : public::testing::Test { ++count; } iterateHandle(aggregation_handle_count_state_.get(), type.makeNullValue()); - CheckCountValue(count, *aggregation_handle_count_, *aggregation_handle_count_state_); + CheckCountValue( + count, *aggregation_handle_count_, *aggregation_handle_count_state_); // Test mergeStates(). std::unique_ptr<AggregationState> merge_state( @@ -277,18 +295,20 @@ class AggregationHandleCountTest : public::testing::Test { iterateHandle(merge_state.get(), type.makeNullValue()); for (int i = 0; i < test_count; ++i) { - iterateHandle(merge_state.get(), type.makeValue(string_literal.c_str(), 10)); + iterateHandle(merge_state.get(), + type.makeValue(string_literal.c_str(), 10)); ++count; } iterateHandle(merge_state.get(), type.makeNullValue()); - aggregation_handle_count_->mergeStates(*merge_state, - aggregation_handle_count_state_.get()); - CheckCountValue(count, *aggregation_handle_count_, *aggregation_handle_count_state_); + aggregation_handle_count_->mergeStates( + *merge_state, aggregation_handle_count_state_.get()); + CheckCountValue( + count, *aggregation_handle_count_, *aggregation_handle_count_state_); } template <typename ColumnVectorType> - ColumnVector *createColumnVectorString(const Type &type, int test_count) { + ColumnVector* createColumnVectorString(const Type &type, int test_count) { ColumnVectorType *column = new ColumnVectorType(type, test_count + 3); std::string string_literal = "test_str"; @@ -296,7 +316,7 @@ class AggregationHandleCountTest : public::testing::Test { for (int i = 0; i < test_count; ++i) { column->appendTypedValue(type.makeValue(string_literal.c_str(), 10)); // One NULL in the middle. - if (i == test_count/2) { + if (i == test_count / 2) { column->appendTypedValue(type.makeNullValue()); } } @@ -309,21 +329,22 @@ class AggregationHandleCountTest : public::testing::Test { void checkAggregationCountStringColumnVector(int test_count) { const StringType &type = StringType::Instance(10, true); initializeHandle(&type); - CheckCountValue(0, *aggregation_handle_count_, *aggregation_handle_count_state_); + CheckCountValue( + 0, *aggregation_handle_count_, *aggregation_handle_count_state_); std::vector<std::unique_ptr<ColumnVector>> column_vectors; - column_vectors.emplace_back(createColumnVectorString<ColumnVectorType>(type, test_count)); + column_vectors.emplace_back( + createColumnVectorString<ColumnVectorType>(type, test_count)); std::unique_ptr<AggregationState> cv_state( aggregation_handle_count_->accumulateColumnVectors(column_vectors)); // Test the state generated directly by accumulateColumnVectors(), and also // test after merging back. - CheckCountValue(test_count, - *aggregation_handle_count_, - *cv_state); + CheckCountValue(test_count, *aggregation_handle_count_, *cv_state); - aggregation_handle_count_->mergeStates(*cv_state, aggregation_handle_count_state_.get()); + aggregation_handle_count_->mergeStates( + *cv_state, aggregation_handle_count_state_.get()); CheckCountValue(test_count, *aggregation_handle_count_, *aggregation_handle_count_state_); @@ -334,22 +355,24 @@ class AggregationHandleCountTest : public::testing::Test { void checkAggregationCountStringValueAccessor(int test_count) { const StringType &type = StringType::Instance(10, true); initializeHandle(&type); - CheckCountValue(0, *aggregation_handle_count_, *aggregation_handle_count_state_); + CheckCountValue( + 0, *aggregation_handle_count_, *aggregation_handle_count_state_); - std::unique_ptr<ColumnVectorsValueAccessor> accessor(new ColumnVectorsValueAccessor()); - accessor->addColumn(createColumnVectorString<ColumnVectorType>(type, test_count)); + std::unique_ptr<ColumnVectorsValueAccessor> accessor( + new ColumnVectorsValueAccessor()); + accessor->addColumn( + createColumnVectorString<ColumnVectorType>(type, test_count)); std::unique_ptr<AggregationState> va_state( - aggregation_handle_count_->accumulateValueAccessor(accessor.get(), - std::vector<attribute_id>(1, 0))); + aggregation_handle_count_->accumulateValueAccessor( + accessor.get(), std::vector<attribute_id>(1, 0))); // Test the state generated directly by accumulateValueAccessor(), and also // test after merging back. - CheckCountValue(test_count, - *aggregation_handle_count_, - *va_state); + CheckCountValue(test_count, *aggregation_handle_count_, *va_state); - aggregation_handle_count_->mergeStates(*va_state, aggregation_handle_count_state_.get()); + aggregation_handle_count_->mergeStates( + *va_state, aggregation_handle_count_state_.get()); CheckCountValue(test_count, *aggregation_handle_count_, *aggregation_handle_count_state_); @@ -364,13 +387,12 @@ class AggregationHandleCountTest : public::testing::Test { typedef AggregationHandleCountTest AggregationHandleCountDeathTest; TEST_F(AggregationHandleCountTest, CountStarTest) { - checkAggregationCountNullary(0), - checkAggregationCountNullary(10000); + checkAggregationCountNullary(0), checkAggregationCountNullary(10000); } TEST_F(AggregationHandleCountTest, CountStarAccumulateTest) { checkAggregationCountNullaryAccumulate(0), - checkAggregationCountNullaryAccumulate(10000); + checkAggregationCountNullaryAccumulate(10000); } TEST_F(AggregationHandleCountTest, IntTypeTest) { @@ -430,7 +452,8 @@ TEST_F(AggregationHandleCountTest, CharTypeColumnVectorTest) { TEST_F(AggregationHandleCountTest, VarCharTypeColumnVectorTest) { checkAggregationCountStringColumnVector<VarCharType, IndirectColumnVector>(0); - checkAggregationCountStringColumnVector<VarCharType, IndirectColumnVector>(10000); + checkAggregationCountStringColumnVector<VarCharType, IndirectColumnVector>( + 10000); } #ifdef QUICKSTEP_ENABLE_VECTOR_COPY_ELISION_SELECTION @@ -460,8 +483,10 @@ TEST_F(AggregationHandleCountTest, CharTypeValueAccessorTest) { } TEST_F(AggregationHandleCountTest, VarCharTypeValueAccessorTest) { - checkAggregationCountStringValueAccessor<VarCharType, IndirectColumnVector>(0); - checkAggregationCountStringValueAccessor<VarCharType, IndirectColumnVector>(10000); + checkAggregationCountStringValueAccessor<VarCharType, IndirectColumnVector>( + 0); + checkAggregationCountStringValueAccessor<VarCharType, IndirectColumnVector>( + 10000); } #endif // QUICKSTEP_ENABLE_VECTOR_COPY_ELISION_SELECTION @@ -486,25 +511,28 @@ TEST_F(AggregationHandleCountTest, GroupByTableMergeTestCount) { initializeHandle(&long_non_null_type); storage_manager_.reset(new StorageManager("./test_count_data")); std::unique_ptr<AggregationStateHashTableBase> source_hash_table( - aggregation_handle_count_->createGroupByHashTable( - HashTableImplType::kSimpleScalarSeparateChaining, + AggregationStateFastHashTableFactory::CreateResizable( + HashTableImplType::kSeparateChaining, std::vector<const Type *>(1, &long_non_null_type), 10, + {aggregation_handle_count_.get()->getPayloadSize()}, + {aggregation_handle_count_.get()}, storage_manager_.get())); std::unique_ptr<AggregationStateHashTableBase> destination_hash_table( - aggregation_handle_count_->createGroupByHashTable( - HashTableImplType::kSimpleScalarSeparateChaining, + AggregationStateFastHashTableFactory::CreateResizable( + HashTableImplType::kSeparateChaining, std::vector<const Type *>(1, &long_non_null_type), 10, + {aggregation_handle_count_.get()->getPayloadSize()}, + {aggregation_handle_count_.get()}, storage_manager_.get())); - AggregationStateHashTable<AggregationStateCount> *destination_hash_table_derived = - static_cast<AggregationStateHashTable<AggregationStateCount> *>( + AggregationStateFastHashTable *destination_hash_table_derived = + static_cast<AggregationStateFastHashTable *>( destination_hash_table.get()); - AggregationStateHashTable<AggregationStateCount> *source_hash_table_derived = - static_cast<AggregationStateHashTable<AggregationStateCount> *>( - source_hash_table.get()); + AggregationStateFastHashTable *source_hash_table_derived = + static_cast<AggregationStateFastHashTable *>(source_hash_table.get()); // TODO(harshad) - Use TemplateUtil::CreateBoolInstantiatedInstance to // generate all the combinations of the bool template arguments and test them. @@ -530,7 +558,8 @@ TEST_F(AggregationHandleCountTest, GroupByTableMergeTestCount) { TypedValue exclusive_key_source_count_val(exclusive_key_source_count); const std::int64_t exclusive_key_destination_count = 1; - TypedValue exclusive_key_destination_count_val(exclusive_key_destination_count); + TypedValue exclusive_key_destination_count_val( + exclusive_key_destination_count); std::unique_ptr<AggregationStateCount> common_key_source_state( static_cast<AggregationStateCount *>( @@ -546,62 +575,86 @@ TEST_F(AggregationHandleCountTest, GroupByTableMergeTestCount) { aggregation_handle_count_->createInitialState())); // Create count value states for keys. - aggregation_handle_count_derived->iterateUnaryInl(common_key_source_state.get(), - common_key_source_count_val); - std::int64_t actual_val = aggregation_handle_count_->finalize(*common_key_source_state) - .getLiteral<std::int64_t>(); + aggregation_handle_count_derived->iterateUnaryInl( + common_key_source_state.get(), common_key_source_count_val); + std::int64_t actual_val = + aggregation_handle_count_->finalize(*common_key_source_state) + .getLiteral<std::int64_t>(); EXPECT_EQ(common_key_source_count_val.getLiteral<std::int64_t>(), actual_val); aggregation_handle_count_derived->iterateUnaryInl( common_key_destination_state.get(), common_key_destination_count_val); - actual_val = aggregation_handle_count_->finalize(*common_key_destination_state) - .getLiteral<std::int64_t>(); - EXPECT_EQ(common_key_destination_count_val.getLiteral<std::int64_t>(), actual_val); + actual_val = + aggregation_handle_count_->finalize(*common_key_destination_state) + .getLiteral<std::int64_t>(); + EXPECT_EQ(common_key_destination_count_val.getLiteral<std::int64_t>(), + actual_val); aggregation_handle_count_derived->iterateUnaryInl( - exclusive_key_destination_state.get(), exclusive_key_destination_count_val); + exclusive_key_destination_state.get(), + exclusive_key_destination_count_val); actual_val = aggregation_handle_count_->finalize(*exclusive_key_destination_state) .getLiteral<std::int64_t>(); - EXPECT_EQ(exclusive_key_destination_count_val.getLiteral<std::int64_t>(), actual_val); + EXPECT_EQ(exclusive_key_destination_count_val.getLiteral<std::int64_t>(), + actual_val); aggregation_handle_count_derived->iterateUnaryInl( exclusive_key_source_state.get(), exclusive_key_source_count_val); actual_val = aggregation_handle_count_->finalize(*exclusive_key_source_state) .getLiteral<std::int64_t>(); - EXPECT_EQ(exclusive_key_source_count_val.getLiteral<std::int64_t>(), actual_val); + EXPECT_EQ(exclusive_key_source_count_val.getLiteral<std::int64_t>(), + actual_val); // Add the key-state pairs to the hash tables. - source_hash_table_derived->putCompositeKey(common_key, - *common_key_source_state); - destination_hash_table_derived->putCompositeKey( - common_key, *common_key_destination_state); - source_hash_table_derived->putCompositeKey(exclusive_source_key, - *exclusive_key_source_state); - destination_hash_table_derived->putCompositeKey( - exclusive_destination_key, *exclusive_key_destination_state); + unsigned char buffer[100]; + buffer[0] = '\0'; + memcpy(buffer + 1, + common_key_source_state.get()->getPayloadAddress(), + aggregation_handle_count_.get()->getPayloadSize()); + source_hash_table_derived->putCompositeKey(common_key, buffer); + + memcpy(buffer + 1, + common_key_destination_state.get()->getPayloadAddress(), + aggregation_handle_count_.get()->getPayloadSize()); + destination_hash_table_derived->putCompositeKey(common_key, buffer); + + memcpy(buffer + 1, + exclusive_key_source_state.get()->getPayloadAddress(), + aggregation_handle_count_.get()->getPayloadSize()); + source_hash_table_derived->putCompositeKey(exclusive_source_key, buffer); + + memcpy(buffer + 1, + exclusive_key_destination_state.get()->getPayloadAddress(), + aggregation_handle_count_.get()->getPayloadSize()); + destination_hash_table_derived->putCompositeKey(exclusive_destination_key, + buffer); EXPECT_EQ(2u, destination_hash_table_derived->numEntries()); EXPECT_EQ(2u, source_hash_table_derived->numEntries()); - aggregation_handle_count_->mergeGroupByHashTables(*source_hash_table, - destination_hash_table.get()); + AggregationOperationState::mergeGroupByHashTables( + source_hash_table.get(), destination_hash_table.get()); EXPECT_EQ(3u, destination_hash_table_derived->numEntries()); CheckCountValue( common_key_destination_count_val.getLiteral<std::int64_t>() + common_key_source_count_val.getLiteral<std::int64_t>(), - *aggregation_handle_count_derived, - *(destination_hash_table_derived->getSingleCompositeKey(common_key))); - CheckCountValue(exclusive_key_destination_count_val.getLiteral<std::int64_t>(), - *aggregation_handle_count_derived, - *(destination_hash_table_derived->getSingleCompositeKey( - exclusive_destination_key))); + aggregation_handle_count_derived->finalizeHashTableEntryFast( + destination_hash_table_derived->getSingleCompositeKey(common_key) + + 1)); + CheckCountValue( + exclusive_key_destination_count_val.getLiteral<std::int64_t>(), + aggregation_handle_count_derived->finalizeHashTableEntryFast( + destination_hash_table_derived->getSingleCompositeKey( + exclusive_destination_key) + + 1)); CheckCountValue(exclusive_key_source_count_val.getLiteral<std::int64_t>(), - *aggregation_handle_count_derived, - *(source_hash_table_derived->getSingleCompositeKey( - exclusive_source_key))); + aggregation_handle_count_derived->finalizeHashTableEntryFast( + source_hash_table_derived->getSingleCompositeKey( + exclusive_source_key) + + 1)); } } // namespace quickstep
http://git-wip-us.apache.org/repos/asf/incubator-quickstep/blob/ac3512ce/expressions/aggregation/tests/AggregationHandleMax_unittest.cpp ---------------------------------------------------------------------- diff --git a/expressions/aggregation/tests/AggregationHandleMax_unittest.cpp b/expressions/aggregation/tests/AggregationHandleMax_unittest.cpp index b7cf02a..026bd1d 100644 --- a/expressions/aggregation/tests/AggregationHandleMax_unittest.cpp +++ b/expressions/aggregation/tests/AggregationHandleMax_unittest.cpp @@ -31,6 +31,8 @@ #include "expressions/aggregation/AggregationHandle.hpp" #include "expressions/aggregation/AggregationHandleMax.hpp" #include "expressions/aggregation/AggregationID.hpp" +#include "storage/AggregationOperationState.hpp" +#include "storage/FastHashTableFactory.hpp" #include "storage/HashTableBase.hpp" #include "storage/StorageManager.hpp" #include "types/CharType.hpp" @@ -70,54 +72,59 @@ class AggregationHandleMaxTest : public ::testing::Test { // Helper method that calls AggregationHandleMax::iterateUnaryInl() to // aggregate 'value' into '*state'. void iterateHandle(AggregationState *state, const TypedValue &value) { - static_cast<const AggregationHandleMax&>(*aggregation_handle_max_).iterateUnaryInl( - static_cast<AggregationStateMax*>(state), - value); + static_cast<const AggregationHandleMax &>(*aggregation_handle_max_) + .iterateUnaryInl(static_cast<AggregationStateMax *>(state), value); } void initializeHandle(const Type &type) { aggregation_handle_max_.reset( - AggregateFunctionFactory::Get(AggregationID::kMax).createHandle( - std::vector<const Type*>(1, &type))); + AggregateFunctionFactory::Get(AggregationID::kMax) + .createHandle(std::vector<const Type *>(1, &type))); aggregation_handle_max_state_.reset( aggregation_handle_max_->createInitialState()); } static bool ApplyToTypesTest(TypeID typeID) { - const Type &type = (typeID == kChar || typeID == kVarChar) ? - TypeFactory::GetType(typeID, static_cast<std::size_t>(10)) : - TypeFactory::GetType(typeID); + const Type &type = + (typeID == kChar || typeID == kVarChar) + ? TypeFactory::GetType(typeID, static_cast<std::size_t>(10)) + : TypeFactory::GetType(typeID); - return AggregateFunctionFactory::Get(AggregationID::kMax).canApplyToTypes( - std::vector<const Type*>(1, &type)); + return AggregateFunctionFactory::Get(AggregationID::kMax) + .canApplyToTypes(std::vector<const Type *>(1, &type)); } static bool ResultTypeForArgumentTypeTest(TypeID input_type_id, TypeID output_type_id) { - const Type *result_type - = AggregateFunctionFactory::Get(AggregationID::kMax).resultTypeForArgumentTypes( - std::vector<const Type*>(1, &TypeFactory::GetType(input_type_id))); + const Type *result_type = + AggregateFunctionFactory::Get(AggregationID::kMax) + .resultTypeForArgumentTypes(std::vector<const Type *>( + 1, &TypeFactory::GetType(input_type_id))); return (result_type->getTypeID() == output_type_id); } template <typename CppType> - static void CheckMaxValue( - CppType expected, - const AggregationHandle &handle, - const AggregationState &state) { + static void CheckMaxValue(CppType expected, + const AggregationHandle &handle, + const AggregationState &state) { EXPECT_EQ(expected, handle.finalize(state).getLiteral<CppType>()); } - static void CheckMaxString( - const std::string &expected, - const AggregationHandle &handle, - const AggregationState &state) { + template <typename CppType> + static void CheckMaxValue(CppType expected, const TypedValue &value) { + EXPECT_EQ(expected, value.getLiteral<CppType>()); + } + + static void CheckMaxString(const std::string &expected, + const AggregationHandle &handle, + const AggregationState &state) { TypedValue value = handle.finalize(state); ASSERT_EQ(expected.length(), value.getAsciiStringLength()); - EXPECT_EQ(0, std::strncmp(expected.c_str(), - static_cast<const char*>(value.getDataPtr()), - value.getAsciiStringLength())); + EXPECT_EQ(0, + std::strncmp(expected.c_str(), + static_cast<const char *>(value.getDataPtr()), + value.getAsciiStringLength())); } // Static templated method to initialize data types. @@ -130,7 +137,9 @@ class AggregationHandleMaxTest : public ::testing::Test { void checkAggregationMaxGeneric() { const GenericType &type = GenericType::Instance(true); initializeHandle(type); - EXPECT_TRUE(aggregation_handle_max_->finalize(*aggregation_handle_max_state_).isNull()); + EXPECT_TRUE( + aggregation_handle_max_->finalize(*aggregation_handle_max_state_) + .isNull()); typename GenericType::cpptype val; typename GenericType::cpptype max; @@ -142,16 +151,18 @@ class AggregationHandleMaxTest : public ::testing::Test { if (type.getTypeID() == kInt || type.getTypeID() == kLong) { SetDataType(i * kNumSamples + j - 10, &val); } else { - SetDataType(static_cast<float>(i * kNumSamples + j - 10)/10, &val); + SetDataType(static_cast<float>(i * kNumSamples + j - 10) / 10, &val); } - iterateHandle(aggregation_handle_max_state_.get(), type.makeValue(&val)); + iterateHandle(aggregation_handle_max_state_.get(), + type.makeValue(&val)); if (max < val) { max = val; } } } iterateHandle(aggregation_handle_max_state_.get(), type.makeNullValue()); - CheckMaxValue<typename GenericType::cpptype>(max, *aggregation_handle_max_, *aggregation_handle_max_state_); + CheckMaxValue<typename GenericType::cpptype>( + max, *aggregation_handle_max_, *aggregation_handle_max_state_); // Test mergeStates(). std::unique_ptr<AggregationState> merge_state( @@ -165,7 +176,7 @@ class AggregationHandleMaxTest : public ::testing::Test { if (type.getTypeID() == kInt || type.getTypeID() == kLong) { SetDataType(i * kNumSamples + j - 20, &val); } else { - SetDataType(static_cast<float>(i * kNumSamples + j - 20)/10, &val); + SetDataType(static_cast<float>(i * kNumSamples + j - 20) / 10, &val); } iterateHandle(merge_state.get(), type.makeValue(&val)); if (max < val) { @@ -176,14 +187,14 @@ class AggregationHandleMaxTest : public ::testing::Test { aggregation_handle_max_->mergeStates(*merge_state, aggregation_handle_max_state_.get()); CheckMaxValue<typename GenericType::cpptype>( - max, - *aggregation_handle_max_, - *aggregation_handle_max_state_); + max, *aggregation_handle_max_, *aggregation_handle_max_state_); } template <typename GenericType> - ColumnVector *createColumnVectorGeneric(const Type &type, typename GenericType::cpptype *max) { - NativeColumnVector *column = new NativeColumnVector(type, kIterations * kNumSamples + 3); + ColumnVector* createColumnVectorGeneric(const Type &type, + typename GenericType::cpptype *max) { + NativeColumnVector *column = + new NativeColumnVector(type, kIterations * kNumSamples + 3); typename GenericType::cpptype val; SetDataType(0, max); @@ -194,7 +205,7 @@ class AggregationHandleMaxTest : public ::testing::Test { if (type.getTypeID() == kInt || type.getTypeID() == kLong) { SetDataType(i * kNumSamples + j - 10, &val); } else { - SetDataType(static_cast<float>(i * kNumSamples + j - 10)/10, &val); + SetDataType(static_cast<float>(i * kNumSamples + j - 10) / 10, &val); } column->appendTypedValue(type.makeValue(&val)); if (*max < val) { @@ -202,7 +213,7 @@ class AggregationHandleMaxTest : public ::testing::Test { } } // One NULL in the middle. - if (i == kIterations/2) { + if (i == kIterations / 2) { column->appendTypedValue(type.makeNullValue()); } } @@ -215,11 +226,14 @@ class AggregationHandleMaxTest : public ::testing::Test { void checkAggregationMaxGenericColumnVector() { const GenericType &type = GenericType::Instance(true); initializeHandle(type); - EXPECT_TRUE(aggregation_handle_max_->finalize(*aggregation_handle_max_state_).isNull()); + EXPECT_TRUE( + aggregation_handle_max_->finalize(*aggregation_handle_max_state_) + .isNull()); typename GenericType::cpptype max; std::vector<std::unique_ptr<ColumnVector>> column_vectors; - column_vectors.emplace_back(createColumnVectorGeneric<GenericType>(type, &max)); + column_vectors.emplace_back( + createColumnVectorGeneric<GenericType>(type, &max)); std::unique_ptr<AggregationState> cv_state( aggregation_handle_max_->accumulateColumnVectors(column_vectors)); @@ -227,15 +241,12 @@ class AggregationHandleMaxTest : public ::testing::Test { // Test the state generated directly by accumulateColumnVectors(), and also // test after merging back. CheckMaxValue<typename GenericType::cpptype>( - max, - *aggregation_handle_max_, - *cv_state); + max, *aggregation_handle_max_, *cv_state); - aggregation_handle_max_->mergeStates(*cv_state, aggregation_handle_max_state_.get()); + aggregation_handle_max_->mergeStates(*cv_state, + aggregation_handle_max_state_.get()); CheckMaxValue<typename GenericType::cpptype>( - max, - *aggregation_handle_max_, - *aggregation_handle_max_state_); + max, *aggregation_handle_max_, *aggregation_handle_max_state_); } #ifdef QUICKSTEP_ENABLE_VECTOR_COPY_ELISION_SELECTION @@ -243,29 +254,29 @@ class AggregationHandleMaxTest : public ::testing::Test { void checkAggregationMaxGenericValueAccessor() { const GenericType &type = GenericType::Instance(true); initializeHandle(type); - EXPECT_TRUE(aggregation_handle_max_->finalize(*aggregation_handle_max_state_).isNull()); + EXPECT_TRUE( + aggregation_handle_max_->finalize(*aggregation_handle_max_state_) + .isNull()); - std::unique_ptr<ColumnVectorsValueAccessor> accessor(new ColumnVectorsValueAccessor()); + std::unique_ptr<ColumnVectorsValueAccessor> accessor( + new ColumnVectorsValueAccessor()); typename GenericType::cpptype max; accessor->addColumn(createColumnVectorGeneric<GenericType>(type, &max)); std::unique_ptr<AggregationState> va_state( - aggregation_handle_max_->accumulateValueAccessor(accessor.get(), - std::vector<attribute_id>(1, 0))); + aggregation_handle_max_->accumulateValueAccessor( + accessor.get(), std::vector<attribute_id>(1, 0))); // Test the state generated directly by accumulateValueAccessor(), and also // test after merging back. CheckMaxValue<typename GenericType::cpptype>( - max, - *aggregation_handle_max_, - *va_state); + max, *aggregation_handle_max_, *va_state); - aggregation_handle_max_->mergeStates(*va_state, aggregation_handle_max_state_.get()); + aggregation_handle_max_->mergeStates(*va_state, + aggregation_handle_max_state_.get()); CheckMaxValue<typename GenericType::cpptype>( - max, - *aggregation_handle_max_, - *aggregation_handle_max_state_); + max, *aggregation_handle_max_, *aggregation_handle_max_state_); } #endif // QUICKSTEP_ENABLE_VECTOR_COPY_ELISION_SELECTION @@ -273,11 +284,14 @@ class AggregationHandleMaxTest : public ::testing::Test { void checkAggregationMaxString() { const StringType &type = StringType::Instance(10, true); initializeHandle(type); - EXPECT_TRUE(aggregation_handle_max_->finalize(*aggregation_handle_max_state_).isNull()); + EXPECT_TRUE( + aggregation_handle_max_->finalize(*aggregation_handle_max_state_) + .isNull()); std::unique_ptr<UncheckedComparator> fast_comparator_; - fast_comparator_.reset(ComparisonFactory::GetComparison(ComparisonID::kGreater) - .makeUncheckedComparatorForTypes(type, type)); + fast_comparator_.reset( + ComparisonFactory::GetComparison(ComparisonID::kGreater) + .makeUncheckedComparatorForTypes(type, type)); std::string string_literal; std::string max = ""; int val; @@ -291,15 +305,17 @@ class AggregationHandleMaxTest : public ::testing::Test { iterateHandle( aggregation_handle_max_state_.get(), - type.makeValue(string_literal.c_str(), - string_literal.length() + 1).ensureNotReference()); - if (fast_comparator_->compareDataPtrs(string_literal.c_str(), max.c_str())) { + type.makeValue(string_literal.c_str(), string_literal.length() + 1) + .ensureNotReference()); + if (fast_comparator_->compareDataPtrs(string_literal.c_str(), + max.c_str())) { max = string_literal; } } } iterateHandle(aggregation_handle_max_state_.get(), type.makeNullValue()); - CheckMaxString(max, *aggregation_handle_max_, *aggregation_handle_max_state_); + CheckMaxString( + max, *aggregation_handle_max_, *aggregation_handle_max_state_); // Test mergeStates(). std::unique_ptr<AggregationState> merge_state( @@ -317,24 +333,28 @@ class AggregationHandleMaxTest : public ::testing::Test { iterateHandle( merge_state.get(), - type.makeValue(string_literal.c_str(), - string_literal.length() + 1).ensureNotReference()); - if (fast_comparator_->compareDataPtrs(string_literal.c_str(), max.c_str())) { + type.makeValue(string_literal.c_str(), string_literal.length() + 1) + .ensureNotReference()); + if (fast_comparator_->compareDataPtrs(string_literal.c_str(), + max.c_str())) { max = string_literal; } } } aggregation_handle_max_->mergeStates(*merge_state, aggregation_handle_max_state_.get()); - CheckMaxString(max, *aggregation_handle_max_, *aggregation_handle_max_state_); + CheckMaxString( + max, *aggregation_handle_max_, *aggregation_handle_max_state_); } template <typename ColumnVectorType> - ColumnVector *createColumnVectorString(const Type &type, std::string *max) { - ColumnVectorType *column = new ColumnVectorType(type, kIterations * kNumSamples + 3); + ColumnVector* createColumnVectorString(const Type &type, std::string *max) { + ColumnVectorType *column = + new ColumnVectorType(type, kIterations * kNumSamples + 3); std::unique_ptr<UncheckedComparator> fast_comparator_; - fast_comparator_.reset(ComparisonFactory::GetComparison(ComparisonID::kGreater) - .makeUncheckedComparatorForTypes(type, type)); + fast_comparator_.reset( + ComparisonFactory::GetComparison(ComparisonID::kGreater) + .makeUncheckedComparatorForTypes(type, type)); std::string string_literal; *max = ""; int val; @@ -346,14 +366,16 @@ class AggregationHandleMaxTest : public ::testing::Test { oss << "max" << val; string_literal = oss.str(); - column->appendTypedValue(type.makeValue(string_literal.c_str(), string_literal.length() + 1) - .ensureNotReference()); - if (fast_comparator_->compareDataPtrs(string_literal.c_str(), max->c_str())) { + column->appendTypedValue( + type.makeValue(string_literal.c_str(), string_literal.length() + 1) + .ensureNotReference()); + if (fast_comparator_->compareDataPtrs(string_literal.c_str(), + max->c_str())) { *max = string_literal; } } // One NULL in the middle. - if (i == kIterations/2) { + if (i == kIterations / 2) { column->appendTypedValue(type.makeNullValue()); } } @@ -366,25 +388,26 @@ class AggregationHandleMaxTest : public ::testing::Test { void checkAggregationMaxStringColumnVector() { const StringType &type = StringType::Instance(10, true); initializeHandle(type); - EXPECT_TRUE(aggregation_handle_max_->finalize(*aggregation_handle_max_state_).isNull()); + EXPECT_TRUE( + aggregation_handle_max_->finalize(*aggregation_handle_max_state_) + .isNull()); std::string max; std::vector<std::unique_ptr<ColumnVector>> column_vectors; - column_vectors.emplace_back(createColumnVectorString<ColumnVectorType>(type, &max)); + column_vectors.emplace_back( + createColumnVectorString<ColumnVectorType>(type, &max)); std::unique_ptr<AggregationState> cv_state( aggregation_handle_max_->accumulateColumnVectors(column_vectors)); // Test the state generated directly by accumulateColumnVectors(), and also // test after merging back. - CheckMaxString(max, - *aggregation_handle_max_, - *cv_state); - - aggregation_handle_max_->mergeStates(*cv_state, aggregation_handle_max_state_.get()); - CheckMaxString(max, - *aggregation_handle_max_, - *aggregation_handle_max_state_); + CheckMaxString(max, *aggregation_handle_max_, *cv_state); + + aggregation_handle_max_->mergeStates(*cv_state, + aggregation_handle_max_state_.get()); + CheckMaxString( + max, *aggregation_handle_max_, *aggregation_handle_max_state_); } #ifdef QUICKSTEP_ENABLE_VECTOR_COPY_ELISION_SELECTION @@ -392,26 +415,27 @@ class AggregationHandleMaxTest : public ::testing::Test { void checkAggregationMaxStringValueAccessor() { const StringType &type = StringType::Instance(10, true); initializeHandle(type); - EXPECT_TRUE(aggregation_handle_max_->finalize(*aggregation_handle_max_state_).isNull()); + EXPECT_TRUE( + aggregation_handle_max_->finalize(*aggregation_handle_max_state_) + .isNull()); std::string max; - std::unique_ptr<ColumnVectorsValueAccessor> accessor(new ColumnVectorsValueAccessor()); + std::unique_ptr<ColumnVectorsValueAccessor> accessor( + new ColumnVectorsValueAccessor()); accessor->addColumn(createColumnVectorString<ColumnVectorType>(type, &max)); std::unique_ptr<AggregationState> va_state( - aggregation_handle_max_->accumulateValueAccessor(accessor.get(), - std::vector<attribute_id>(1, 0))); + aggregation_handle_max_->accumulateValueAccessor( + accessor.get(), std::vector<attribute_id>(1, 0))); // Test the state generated directly by accumulateValueAccessor(), and also // test after merging back. - CheckMaxString(max, - *aggregation_handle_max_, - *va_state); - - aggregation_handle_max_->mergeStates(*va_state, aggregation_handle_max_state_.get()); - CheckMaxString(max, - *aggregation_handle_max_, - *aggregation_handle_max_state_); + CheckMaxString(max, *aggregation_handle_max_, *va_state); + + aggregation_handle_max_->mergeStates(*va_state, + aggregation_handle_max_state_.get()); + CheckMaxString( + max, *aggregation_handle_max_, *aggregation_handle_max_state_); } #endif // QUICKSTEP_ENABLE_VECTOR_COPY_ELISION_SELECTION @@ -422,9 +446,7 @@ class AggregationHandleMaxTest : public ::testing::Test { template <> void AggregationHandleMaxTest::CheckMaxValue<float>( - float val, - const AggregationHandle &handle, - const AggregationState &state) { + float val, const AggregationHandle &handle, const AggregationState &state) { EXPECT_FLOAT_EQ(val, handle.finalize(state).getLiteral<float>()); } @@ -437,17 +459,20 @@ void AggregationHandleMaxTest::CheckMaxValue<double>( } template <> -void AggregationHandleMaxTest::SetDataType<DatetimeLit>(int value, DatetimeLit *data) { +void AggregationHandleMaxTest::SetDataType<DatetimeLit>(int value, + DatetimeLit *data) { data->ticks = value; } template <> -void AggregationHandleMaxTest::SetDataType<DatetimeIntervalLit>(int value, DatetimeIntervalLit *data) { +void AggregationHandleMaxTest::SetDataType<DatetimeIntervalLit>( + int value, DatetimeIntervalLit *data) { data->interval_ticks = value; } template <> -void AggregationHandleMaxTest::SetDataType<YearMonthIntervalLit>(int value, YearMonthIntervalLit *data) { +void AggregationHandleMaxTest::SetDataType<YearMonthIntervalLit>( + int value, YearMonthIntervalLit *data) { data->months = value; } @@ -579,50 +604,67 @@ TEST_F(AggregationHandleMaxDeathTest, WrongTypeTest) { float float_val = 0; // Passes. - iterateHandle(aggregation_handle_max_state_.get(), int_non_null_type.makeValue(&int_val)); + iterateHandle(aggregation_handle_max_state_.get(), + int_non_null_type.makeValue(&int_val)); - EXPECT_DEATH(iterateHandle(aggregation_handle_max_state_.get(), long_type.makeValue(&long_val)), ""); - EXPECT_DEATH(iterateHandle(aggregation_handle_max_state_.get(), double_type.makeValue(&double_val)), ""); - EXPECT_DEATH(iterateHandle(aggregation_handle_max_state_.get(), float_type.makeValue(&float_val)), ""); - EXPECT_DEATH(iterateHandle(aggregation_handle_max_state_.get(), char_type.makeValue("asdf", 5)), ""); - EXPECT_DEATH(iterateHandle(aggregation_handle_max_state_.get(), varchar_type.makeValue("asdf", 5)), ""); + EXPECT_DEATH(iterateHandle(aggregation_handle_max_state_.get(), + long_type.makeValue(&long_val)), + ""); + EXPECT_DEATH(iterateHandle(aggregation_handle_max_state_.get(), + double_type.makeValue(&double_val)), + ""); + EXPECT_DEATH(iterateHandle(aggregation_handle_max_state_.get(), + float_type.makeValue(&float_val)), + ""); + EXPECT_DEATH(iterateHandle(aggregation_handle_max_state_.get(), + char_type.makeValue("asdf", 5)), + ""); + EXPECT_DEATH(iterateHandle(aggregation_handle_max_state_.get(), + varchar_type.makeValue("asdf", 5)), + ""); // Test mergeStates() with incorrectly typed handles. std::unique_ptr<AggregationHandle> aggregation_handle_max_long( - AggregateFunctionFactory::Get(AggregationID::kMax).createHandle( - std::vector<const Type*>(1, &long_type))); + AggregateFunctionFactory::Get(AggregationID::kMax) + .createHandle(std::vector<const Type *>(1, &long_type))); std::unique_ptr<AggregationState> aggregation_state_max_merge_long( aggregation_handle_max_long->createInitialState()); - static_cast<const AggregationHandleMax&>(*aggregation_handle_max_long).iterateUnaryInl( - static_cast<AggregationStateMax*>(aggregation_state_max_merge_long.get()), - long_type.makeValue(&long_val)); - EXPECT_DEATH(aggregation_handle_max_->mergeStates(*aggregation_state_max_merge_long, - aggregation_handle_max_state_.get()), - ""); + static_cast<const AggregationHandleMax &>(*aggregation_handle_max_long) + .iterateUnaryInl(static_cast<AggregationStateMax *>( + aggregation_state_max_merge_long.get()), + long_type.makeValue(&long_val)); + EXPECT_DEATH( + aggregation_handle_max_->mergeStates(*aggregation_state_max_merge_long, + aggregation_handle_max_state_.get()), + ""); std::unique_ptr<AggregationHandle> aggregation_handle_max_double( - AggregateFunctionFactory::Get(AggregationID::kMax).createHandle( - std::vector<const Type*>(1, &double_type))); + AggregateFunctionFactory::Get(AggregationID::kMax) + .createHandle(std::vector<const Type *>(1, &double_type))); std::unique_ptr<AggregationState> aggregation_state_max_merge_double( aggregation_handle_max_double->createInitialState()); - static_cast<const AggregationHandleMax&>(*aggregation_handle_max_double).iterateUnaryInl( - static_cast<AggregationStateMax*>(aggregation_state_max_merge_double.get()), - double_type.makeValue(&double_val)); - EXPECT_DEATH(aggregation_handle_max_->mergeStates(*aggregation_state_max_merge_double, - aggregation_handle_max_state_.get()), - ""); + static_cast<const AggregationHandleMax &>(*aggregation_handle_max_double) + .iterateUnaryInl(static_cast<AggregationStateMax *>( + aggregation_state_max_merge_double.get()), + double_type.makeValue(&double_val)); + EXPECT_DEATH( + aggregation_handle_max_->mergeStates(*aggregation_state_max_merge_double, + aggregation_handle_max_state_.get()), + ""); std::unique_ptr<AggregationHandle> aggregation_handle_max_float( - AggregateFunctionFactory::Get(AggregationID::kMax).createHandle( - std::vector<const Type*>(1, &float_type))); + AggregateFunctionFactory::Get(AggregationID::kMax) + .createHandle(std::vector<const Type *>(1, &float_type))); std::unique_ptr<AggregationState> aggregation_state_max_merge_float( aggregation_handle_max_float->createInitialState()); - static_cast<const AggregationHandleMax&>(*aggregation_handle_max_float).iterateUnaryInl( - static_cast<AggregationStateMax*>(aggregation_state_max_merge_float.get()), - float_type.makeValue(&float_val)); - EXPECT_DEATH(aggregation_handle_max_->mergeStates(*aggregation_state_max_merge_float, - aggregation_handle_max_state_.get()), - ""); + static_cast<const AggregationHandleMax &>(*aggregation_handle_max_float) + .iterateUnaryInl(static_cast<AggregationStateMax *>( + aggregation_state_max_merge_float.get()), + float_type.makeValue(&float_val)); + EXPECT_DEATH( + aggregation_handle_max_->mergeStates(*aggregation_state_max_merge_float, + aggregation_handle_max_state_.get()), + ""); } #endif @@ -647,25 +689,28 @@ TEST_F(AggregationHandleMaxTest, GroupByTableMergeTest) { initializeHandle(int_non_null_type); storage_manager_.reset(new StorageManager("./test_max_data")); std::unique_ptr<AggregationStateHashTableBase> source_hash_table( - aggregation_handle_max_->createGroupByHashTable( - HashTableImplType::kSimpleScalarSeparateChaining, + AggregationStateFastHashTableFactory::CreateResizable( + HashTableImplType::kSeparateChaining, std::vector<const Type *>(1, &int_non_null_type), 10, + {aggregation_handle_max_.get()->getPayloadSize()}, + {aggregation_handle_max_.get()}, storage_manager_.get())); std::unique_ptr<AggregationStateHashTableBase> destination_hash_table( - aggregation_handle_max_->createGroupByHashTable( - HashTableImplType::kSimpleScalarSeparateChaining, + AggregationStateFastHashTableFactory::CreateResizable( + HashTableImplType::kSeparateChaining, std::vector<const Type *>(1, &int_non_null_type), 10, + {aggregation_handle_max_.get()->getPayloadSize()}, + {aggregation_handle_max_.get()}, storage_manager_.get())); - AggregationStateHashTable<AggregationStateMax> *destination_hash_table_derived = - static_cast<AggregationStateHashTable<AggregationStateMax> *>( + AggregationStateFastHashTable *destination_hash_table_derived = + static_cast<AggregationStateFastHashTable *>( destination_hash_table.get()); - AggregationStateHashTable<AggregationStateMax> *source_hash_table_derived = - static_cast<AggregationStateHashTable<AggregationStateMax> *>( - source_hash_table.get()); + AggregationStateFastHashTable *source_hash_table_derived = + static_cast<AggregationStateFastHashTable *>(source_hash_table.get()); AggregationHandleMax *aggregation_handle_max_derived = static_cast<AggregationHandleMax *>(aggregation_handle_max_.get()); @@ -730,35 +775,52 @@ TEST_F(AggregationHandleMaxTest, GroupByTableMergeTest) { EXPECT_EQ(exclusive_key_source_max_val.getLiteral<int>(), actual_val); // Add the key-state pairs to the hash tables. - source_hash_table_derived->putCompositeKey(common_key, - *common_key_source_state); - destination_hash_table_derived->putCompositeKey( - common_key, *common_key_destination_state); - source_hash_table_derived->putCompositeKey(exclusive_source_key, - *exclusive_key_source_state); - destination_hash_table_derived->putCompositeKey( - exclusive_destination_key, *exclusive_key_destination_state); + unsigned char buffer[100]; + buffer[0] = '\0'; + memcpy(buffer + 1, + common_key_source_state.get()->getPayloadAddress(), + aggregation_handle_max_.get()->getPayloadSize()); + source_hash_table_derived->putCompositeKey(common_key, buffer); + + memcpy(buffer + 1, + common_key_destination_state.get()->getPayloadAddress(), + aggregation_handle_max_.get()->getPayloadSize()); + destination_hash_table_derived->putCompositeKey(common_key, buffer); + + memcpy(buffer + 1, + exclusive_key_source_state.get()->getPayloadAddress(), + aggregation_handle_max_.get()->getPayloadSize()); + source_hash_table_derived->putCompositeKey(exclusive_source_key, buffer); + + memcpy(buffer + 1, + exclusive_key_destination_state.get()->getPayloadAddress(), + aggregation_handle_max_.get()->getPayloadSize()); + destination_hash_table_derived->putCompositeKey(exclusive_destination_key, + buffer); EXPECT_EQ(2u, destination_hash_table_derived->numEntries()); EXPECT_EQ(2u, source_hash_table_derived->numEntries()); - aggregation_handle_max_->mergeGroupByHashTables(*source_hash_table, - destination_hash_table.get()); + AggregationOperationState::mergeGroupByHashTables( + source_hash_table.get(), destination_hash_table.get()); EXPECT_EQ(3u, destination_hash_table_derived->numEntries()); CheckMaxValue<int>( common_key_destination_max_val.getLiteral<int>(), - *aggregation_handle_max_derived, - *(destination_hash_table_derived->getSingleCompositeKey(common_key))); + aggregation_handle_max_derived->finalizeHashTableEntryFast( + destination_hash_table_derived->getSingleCompositeKey(common_key) + + 1)); CheckMaxValue<int>(exclusive_key_destination_max_val.getLiteral<int>(), - *aggregation_handle_max_derived, - *(destination_hash_table_derived->getSingleCompositeKey( - exclusive_destination_key))); + aggregation_handle_max_derived->finalizeHashTableEntryFast( + destination_hash_table_derived->getSingleCompositeKey( + exclusive_destination_key) + + 1)); CheckMaxValue<int>(exclusive_key_source_max_val.getLiteral<int>(), - *aggregation_handle_max_derived, - *(source_hash_table_derived->getSingleCompositeKey( - exclusive_source_key))); + aggregation_handle_max_derived->finalizeHashTableEntryFast( + source_hash_table_derived->getSingleCompositeKey( + exclusive_source_key) + + 1)); } } // namespace quickstep