This is an automated email from the ASF dual-hosted git repository. jmalkin pushed a commit to branch count_min_python in repository https://gitbox.apache.org/repos/asf/datasketches-cpp.git
commit c8cf06f434bd20bacf149f7793a8e4b7df6f45f0 Author: Jon Malkin <[email protected]> AuthorDate: Thu Apr 6 00:17:33 2023 -0700 Add apache license header to files. Add allocator support, clean up serialized size --- count/include/count_min.hpp | 58 ++++++-- count/include/count_min_impl.hpp | 223 ++++++++++++++++--------------- count/test/CMakeLists.txt | 1 + count/test/count_min_allocation_test.cpp | 155 +++++++++++++++++++++ count/test/count_min_test.cpp | 26 +++- 5 files changed, 344 insertions(+), 119 deletions(-) diff --git a/count/include/count_min.hpp b/count/include/count_min.hpp index c4bd752..c54fb25 100644 --- a/count/include/count_min.hpp +++ b/count/include/count_min.hpp @@ -1,3 +1,22 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + #ifndef COUNT_MIN_HPP_ #define COUNT_MIN_HPP_ @@ -14,12 +33,13 @@ namespace datasketches { * @author Charlie Dickens */ -template<typename W> +template <typename W, + typename Allocator = std::allocator<W>> class count_min_sketch{ static_assert(std::is_arithmetic<W>::value, "Arithmetic type expected"); public: + using allocator_type = Allocator; - using vector_bytes = std::vector<uint8_t>; /** * Creates an instance of the sketch given parameters _num_hashes, _num_buckets and hash seed, `seed`. * @param num_hashes : number of hash functions in the sketch. Equivalently the number of rows in the array @@ -29,7 +49,7 @@ public: * The items inserted into the sketch can be arbitrary type, so long as they are hashable via murmurhash. * Only update and estimate methods are added for uint64_t and string types. */ - count_min_sketch(uint8_t num_hashes, uint32_t num_buckets, uint64_t seed = DEFAULT_SEED) ; + count_min_sketch(uint8_t num_hashes, uint32_t num_buckets, uint64_t seed = DEFAULT_SEED, const Allocator& allocator = Allocator()) ; /** * @return configured _num_hashes of this sketch @@ -169,7 +189,7 @@ public: /* * merges a separate count_min_sketch into this count_min_sketch. */ - void merge(const count_min_sketch<W> &other_sketch) ; + void merge(const count_min_sketch &other_sketch) ; /** * Returns true if this sketch is empty. @@ -183,7 +203,7 @@ public: * @brief Returns a string describing the sketch * @return A string with a human-readable description of the sketch */ - string<std::allocator<W>> to_string() const; + string<Allocator> to_string() const; // Iterators using const_iterator = typename std::vector<W>::const_iterator ; @@ -238,8 +258,22 @@ public: * */ + + /** + * Computes size needed to serialize the current state of the sketch. + * @return size in bytes needed to serialize this sketch + */ + size_t get_serialized_size_bytes() const; + + /** + * This method serializes a binary image of the sketch to an output stream. + */ void serialize(std::ostream& os) const; + // This is a convenience alias for users + // The type returned by the following serialize method + using vector_bytes = std::vector<uint8_t, typename std::allocator_traits<Allocator>::template rebind_alloc<uint8_t>>; + /** * This method serializes the sketch as a vector of bytes. * An optional header can be reserved in front of the sketch. @@ -255,8 +289,7 @@ public: * @param seed the seed for the hash function that was used to create the sketch * @return an instance of a sketch */ - //static count_min_sketch deserialize(std::istream& is, uint64_t seed=DEFAULT_SEED) const; - static count_min_sketch deserialize(std::istream& is, uint64_t seed) ; + static count_min_sketch deserialize(std::istream& is, uint64_t seed=DEFAULT_SEED, const Allocator& allocator = Allocator()); /** * This method deserializes a sketch from a given array of bytes. @@ -265,12 +298,19 @@ public: * @param seed the seed for the hash function that was used to create the sketch * @return an instance of the sketch */ - static count_min_sketch deserialize(const void* bytes, size_t size, uint64_t seed=DEFAULT_SEED); + static count_min_sketch deserialize(const void* bytes, size_t size, uint64_t seed=DEFAULT_SEED, const Allocator& allocator = Allocator()); + + /** + * Returns the allocator for this sketch. + * @return allocator + */ + allocator_type get_allocator() const; private: + Allocator _allocator; uint8_t _num_hashes ; uint32_t _num_buckets ; - std::vector<W> _sketch_array ; // the array stored by the sketch + std::vector<W, Allocator> _sketch_array ; // the array stored by the sketch uint64_t _seed ; W _total_weight ; std::vector<uint64_t> hash_seeds ; diff --git a/count/include/count_min_impl.hpp b/count/include/count_min_impl.hpp index 8a1551f..7f414a5 100644 --- a/count/include/count_min_impl.hpp +++ b/count/include/count_min_impl.hpp @@ -1,3 +1,22 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + #ifndef COUNT_MIN_IMPL_HPP_ #define COUNT_MIN_IMPL_HPP_ @@ -11,11 +30,12 @@ namespace datasketches { -template<typename W> -count_min_sketch<W>::count_min_sketch(uint8_t num_hashes, uint32_t num_buckets, uint64_t seed): +template<typename W, typename A> +count_min_sketch<W,A>::count_min_sketch(uint8_t num_hashes, uint32_t num_buckets, uint64_t seed, const A& allocator): +_allocator(allocator), _num_hashes(num_hashes), _num_buckets(num_buckets), -_sketch_array((num_hashes*num_buckets < 1<<30) ? num_hashes*num_buckets : 0, 0), +_sketch_array((num_hashes*num_buckets < 1<<30) ? num_hashes*num_buckets : 0, 0, _allocator), _seed(seed), _total_weight(0){ if(num_buckets < 3) throw std::invalid_argument("Using fewer than 3 buckets incurs relative error greater than 1.") ; @@ -27,7 +47,6 @@ _total_weight(0){ "Try reducing either the number of buckets or the number of hash functions.") ; } - std::default_random_engine rng(_seed); std::uniform_int_distribution<uint64_t> extra_hash_seeds(0, std::numeric_limits<uint64_t>::max()); hash_seeds.reserve(num_hashes) ; @@ -37,33 +56,33 @@ _total_weight(0){ } } -template<typename W> -uint8_t count_min_sketch<W>::get_num_hashes() const{ +template<typename W, typename A> +uint8_t count_min_sketch<W,A>::get_num_hashes() const{ return _num_hashes ; } -template<typename W> -uint32_t count_min_sketch<W>::get_num_buckets() const{ +template<typename W, typename A> +uint32_t count_min_sketch<W,A>::get_num_buckets() const{ return _num_buckets ; } -template<typename W> -uint64_t count_min_sketch<W>::get_seed() const { +template<typename W, typename A> +uint64_t count_min_sketch<W,A>::get_seed() const { return _seed ; } -template<typename W> -double count_min_sketch<W>::get_relative_error() const{ +template<typename W, typename A> +double count_min_sketch<W,A>::get_relative_error() const{ return exp(1.0) / double(_num_buckets) ; } -template<typename W> -W count_min_sketch<W>::get_total_weight() const{ +template<typename W, typename A> +W count_min_sketch<W,A>::get_total_weight() const{ return _total_weight ; } -template<typename W> -uint32_t count_min_sketch<W>::suggest_num_buckets(double relative_error){ +template<typename W, typename A> +uint32_t count_min_sketch<W,A>::suggest_num_buckets(double relative_error){ /* * Function to help users select a number of buckets for a given error. * TODO: Change this when we use only power of 2 buckets. @@ -75,8 +94,8 @@ uint32_t count_min_sketch<W>::suggest_num_buckets(double relative_error){ return ceil(exp(1.0) / relative_error); } -template<typename W> -uint8_t count_min_sketch<W>::suggest_num_hashes(double confidence){ +template<typename W, typename A> +uint8_t count_min_sketch<W,A>::suggest_num_hashes(double confidence){ /* * Function to help users select a number of hashes for a given confidence * e.g. confidence = 1 - failure probability @@ -88,8 +107,8 @@ uint8_t count_min_sketch<W>::suggest_num_hashes(double confidence){ return std::min<uint8_t>( ceil(log(1.0/(1.0 - confidence))), UINT8_MAX) ; } -template<typename W> -std::vector<uint64_t> count_min_sketch<W>::get_hashes(const void* item, size_t size) const{ +template<typename W, typename A> +std::vector<uint64_t> count_min_sketch<W,A>::get_hashes(const void* item, size_t size) const{ /* * Returns the hash locations for the input item using the original hashing * scheme from [1]. @@ -120,20 +139,20 @@ std::vector<uint64_t> count_min_sketch<W>::get_hashes(const void* item, size_t s return sketch_update_locations ; } -template<typename W> -W count_min_sketch<W>::get_estimate(uint64_t item) const {return get_estimate(&item, sizeof(item));} +template<typename W, typename A> +W count_min_sketch<W,A>::get_estimate(uint64_t item) const {return get_estimate(&item, sizeof(item));} -template<typename W> -W count_min_sketch<W>::get_estimate(int64_t item) const {return get_estimate(&item, sizeof(item));} +template<typename W, typename A> +W count_min_sketch<W,A>::get_estimate(int64_t item) const {return get_estimate(&item, sizeof(item));} -template<typename W> -W count_min_sketch<W>::get_estimate(const std::string& item) const { +template<typename W, typename A> +W count_min_sketch<W,A>::get_estimate(const std::string& item) const { if (item.empty()) return 0 ; // Empty strings are not inserted into the sketch. return get_estimate(item.c_str(), item.length()); } -template<typename W> -W count_min_sketch<W>::get_estimate(const void* item, size_t size) const { +template<typename W, typename A> +W count_min_sketch<W,A>::get_estimate(const void* item, size_t size) const { /* * Returns the estimated frequency of the item */ @@ -146,40 +165,40 @@ W count_min_sketch<W>::get_estimate(const void* item, size_t size) const { return result ; } -template<typename W> -void count_min_sketch<W>::update(uint64_t item, W weight) { +template<typename W, typename A> +void count_min_sketch<W,A>::update(uint64_t item, W weight) { update(&item, sizeof(item), weight); } -template<typename W> -void count_min_sketch<W>::update(uint64_t item) { +template<typename W, typename A> +void count_min_sketch<W,A>::update(uint64_t item) { update(&item, sizeof(item), 1); } -template<typename W> -void count_min_sketch<W>::update(int64_t item, W weight) { +template<typename W, typename A> +void count_min_sketch<W,A>::update(int64_t item, W weight) { update(&item, sizeof(item), weight); } -template<typename W> -void count_min_sketch<W>::update(int64_t item) { +template<typename W, typename A> +void count_min_sketch<W,A>::update(int64_t item) { update(&item, sizeof(item), 1); } -template<typename W> -void count_min_sketch<W>::update(const std::string& item, W weight) { +template<typename W, typename A> +void count_min_sketch<W,A>::update(const std::string& item, W weight) { if (item.empty()) return; update(item.c_str(), item.length(), weight); } -template<typename W> -void count_min_sketch<W>::update(const std::string& item) { +template<typename W, typename A> +void count_min_sketch<W,A>::update(const std::string& item) { if (item.empty()) return; update(item.c_str(), item.length(), 1); } -template<typename W> -void count_min_sketch<W>::update(const void* item, size_t size, W weight) { +template<typename W, typename A> +void count_min_sketch<W,A>::update(const void* item, size_t size, W weight) { /* * Gets the item's hash locations and then increments the sketch in those * locations by the weight. @@ -192,42 +211,42 @@ void count_min_sketch<W>::update(const void* item, size_t size, W weight) { } } -template<typename W> -W count_min_sketch<W>::get_upper_bound(uint64_t item) const {return get_upper_bound(&item, sizeof(item));} +template<typename W, typename A> +W count_min_sketch<W,A>::get_upper_bound(uint64_t item) const {return get_upper_bound(&item, sizeof(item));} -template<typename W> -W count_min_sketch<W>::get_upper_bound(int64_t item) const {return get_upper_bound(&item, sizeof(item));} +template<typename W, typename A> +W count_min_sketch<W,A>::get_upper_bound(int64_t item) const {return get_upper_bound(&item, sizeof(item));} -template<typename W> -W count_min_sketch<W>::get_upper_bound(const std::string& item) const { +template<typename W, typename A> +W count_min_sketch<W,A>::get_upper_bound(const std::string& item) const { if (item.empty()) return 0 ; // Empty strings are not inserted into the sketch. return get_upper_bound(item.c_str(), item.length()); } -template<typename W> -W count_min_sketch<W>::get_upper_bound(const void* item, size_t size) const { +template<typename W, typename A> +W count_min_sketch<W,A>::get_upper_bound(const void* item, size_t size) const { return get_estimate(item, size) + get_relative_error()*get_total_weight() ; } -template<typename W> -W count_min_sketch<W>::get_lower_bound(uint64_t item) const {return get_lower_bound(&item, sizeof(item));} +template<typename W, typename A> +W count_min_sketch<W,A>::get_lower_bound(uint64_t item) const {return get_lower_bound(&item, sizeof(item));} -template<typename W> -W count_min_sketch<W>::get_lower_bound(int64_t item) const {return get_lower_bound(&item, sizeof(item));} +template<typename W, typename A> +W count_min_sketch<W,A>::get_lower_bound(int64_t item) const {return get_lower_bound(&item, sizeof(item));} -template<typename W> -W count_min_sketch<W>::get_lower_bound(const std::string& item) const { +template<typename W, typename A> +W count_min_sketch<W,A>::get_lower_bound(const std::string& item) const { if (item.empty()) return 0 ; // Empty strings are not inserted into the sketch. return get_lower_bound(item.c_str(), item.length()); } -template<typename W> -W count_min_sketch<W>::get_lower_bound(const void* item, size_t size) const { +template<typename W, typename A> +W count_min_sketch<W,A>::get_lower_bound(const void* item, size_t size) const { return get_estimate(item, size) ; } -template<typename W> -void count_min_sketch<W>::merge(const count_min_sketch<W> &other_sketch){ +template<typename W, typename A> +void count_min_sketch<W,A>::merge(const count_min_sketch &other_sketch){ /* * Merges this sketch into other_sketch sketch by elementwise summing of buckets */ @@ -255,29 +274,21 @@ void count_min_sketch<W>::merge(const count_min_sketch<W> &other_sketch){ } // Iterators -template<typename W> -typename count_min_sketch<W>::const_iterator count_min_sketch<W>::begin() const { +template<typename W, typename A> +typename count_min_sketch<W,A>::const_iterator count_min_sketch<W,A>::begin() const { return _sketch_array.begin(); } -template<typename W> -typename count_min_sketch<W>::const_iterator count_min_sketch<W>::end() const { +template<typename W, typename A> +typename count_min_sketch<W,A>::const_iterator count_min_sketch<W,A>::end() const { return _sketch_array.end(); } -template<typename W> -void count_min_sketch<W>::serialize(std::ostream& os) const { - // Variable table bytes is used to determine how many bytes to allocate for the sketch table. - // We assume that 8 bytes are necessary per entry in the table. - // The extra 1 is for the total_weight variable which will be zero iff the sketch is empty. - // Hence, table_bytes == 0 iff sketch is empty <=> preamble_longs == 1 - const size_t table_bytes(is_empty() ? 0 : (1 + _num_hashes) * _num_buckets); - const size_t size = sizeof(uint64_t) * (2 + table_bytes); - vector_bytes bytes(size, 0); - +template<typename W, typename A> +void count_min_sketch<W,A>::serialize(std::ostream& os) const { // Long 0 - // The first 4 (of 8) bytes are either 1 or 0 (denoting empty vs non-empty) and the final 4 bytes are unused. - const uint8_t preamble_longs = is_empty() ? PREAMBLE_LONGS_SHORT : PREAMBLE_LONGS_FULL; + //const uint8_t preamble_longs = is_empty() ? PREAMBLE_LONGS_SHORT : PREAMBLE_LONGS_FULL; + const uint8_t preamble_longs = PREAMBLE_LONGS_SHORT; const uint8_t ser_ver = SERIAL_VERSION_1; const uint8_t family_id = FAMILY_ID ; const uint8_t flags_byte = (is_empty() ? 1 << flags::IS_EMPTY : 0); @@ -311,8 +322,8 @@ void count_min_sketch<W>::serialize(std::ostream& os) const { } } -template<typename W> -count_min_sketch<W> count_min_sketch<W>::deserialize(std::istream& is, uint64_t seed) { +template<typename W, typename A> +auto count_min_sketch<W,A>::deserialize(std::istream& is, uint64_t seed, const A& allocator) -> count_min_sketch { // First 8 bytes are 4 bytes of preamble and 4 unused bytes. const auto preamble_longs = read<uint8_t>(is) ; @@ -333,7 +344,7 @@ count_min_sketch<W> count_min_sketch<W>::deserialize(std::istream& is, uint64_t throw std::invalid_argument("Incompatible seed hashes: " + std::to_string(seed_hash) + ", " + std::to_string(compute_seed_hash(seed))); } - count_min_sketch<W> c(nhashes, nbuckets, seed) ; + count_min_sketch c(nhashes, nbuckets, seed, allocator) ; const bool is_empty = (flags_byte & (1 << flags::IS_EMPTY)) > 0; if (is_empty == 1) return c ; // sketch is empty, no need to read further. @@ -345,24 +356,23 @@ count_min_sketch<W> count_min_sketch<W>::deserialize(std::istream& is, uint64_t return c ; } -template<typename W> -auto count_min_sketch<W>::serialize(unsigned header_size_bytes) const -> vector_bytes { +template<typename W, typename A> +size_t count_min_sketch<W,A>::get_serialized_size_bytes() const { + // The header is always 2 bytes, whether empty or full + size_t preamble_longs = PREAMBLE_LONGS_SHORT; - // The first 4 (of 8) bytes are either 1 or 0 (denoting empty vs non-empty) and the final 4 bytes are unused. - const uint8_t preamble_longs = is_empty() ? PREAMBLE_LONGS_SHORT : PREAMBLE_LONGS_FULL; + // If the sketch is empty, we're done. Otherwise, we need the total weight + // held by the sketch as well as a data table of size (num_buckets * num_hashes) + return preamble_longs + (is_empty() ? 0 : sizeof(W) * (1 + _num_buckets * _num_hashes)); +} - // Variable table bytes is used to determine how many bytes to allocate for the sketch table. - // We assume that 8 bytes are necessary per entry in the table. - // The extra 1 is for the total_weight variable which will be zero iff the sketch is empty. - // Hence, table_bytes == 0 iff sketch is empty <=> preamble_longs == 1 - const size_t table_bytes(is_empty() ? 0 : (1 + _num_hashes) * _num_buckets); - const size_t size = header_size_bytes + sizeof(uint64_t) * (2 + table_bytes); - vector_bytes bytes(size, 0); +template<typename W, typename A> +auto count_min_sketch<W,A>::serialize(unsigned header_size_bytes) const -> vector_bytes { + vector_bytes bytes(header_size_bytes + get_serialized_size_bytes(), 0, _allocator); uint8_t *ptr = bytes.data() + header_size_bytes; - //std::cout<< "Preamble Long: " << preamble_longs << std::endl; - //std::cout<< "Writing " << size << " bytes." << std::endl; // Long 0 + const uint8_t preamble_longs = PREAMBLE_LONGS_SHORT; ptr += copy_to_mem(preamble_longs, ptr); const uint8_t ser_ver = SERIAL_VERSION_1; ptr += copy_to_mem(ser_ver, ptr); @@ -398,8 +408,8 @@ auto count_min_sketch<W>::serialize(unsigned header_size_bytes) const -> vector_ return bytes; } -template<typename W> -count_min_sketch<W> count_min_sketch<W>::deserialize(const void* bytes, size_t size, uint64_t seed) { +template<typename W, typename A> +auto count_min_sketch<W,A>::deserialize(const void* bytes, size_t size, uint64_t seed, const A& allocator) -> count_min_sketch { const char* ptr = static_cast<const char*>(bytes); // First 8 bytes are 4 bytes of preamble and 4 unused bytes. @@ -428,7 +438,7 @@ count_min_sketch<W> count_min_sketch<W>::deserialize(const void* bytes, size_t s throw std::invalid_argument("Incompatible seed hashes: " + std::to_string(seed_hash) + ", " + std::to_string(compute_seed_hash(seed))); } - count_min_sketch<W> c(nhashes, nbuckets, seed) ; + count_min_sketch c(nhashes, nbuckets, seed, allocator) ; const bool is_empty = (flags_byte & (1 << flags::IS_EMPTY)) > 0; if (is_empty) return c ; // sketch is empty, no need to read further. @@ -444,13 +454,13 @@ count_min_sketch<W> count_min_sketch<W>::deserialize(const void* bytes, size_t s return c; } -template<typename W> -bool count_min_sketch<W>::is_empty() const { +template<typename W, typename A> +bool count_min_sketch<W,A>::is_empty() const { return _total_weight == 0; } -template<typename W> -string<std::allocator<W>> count_min_sketch<W>::to_string() const { +template<typename W, typename A> +string<A> count_min_sketch<W,A>::to_string() const { // count the number of used entries in the sketch uint64_t num_nonzero = 0; for (auto entry : _sketch_array) { @@ -469,20 +479,20 @@ string<std::allocator<W>> count_min_sketch<W>::to_string() const { os << " pct filled : " << std::setprecision(3) << (num_nonzero * 100.0) / _sketch_array.size() << "%" << std::endl; os << "### End sketch summary" << std::endl; - //return string<A>(os.str().c_str(), allocator_); - return string<std::allocator<W>>(os.str().c_str()); + return string<A>(os.str().c_str(), _allocator); } -template<typename W> -void count_min_sketch<W>::check_header_validity(uint8_t preamble_longs, uint8_t serial_version, uint8_t family_id, uint8_t flags_byte) { +template<typename W, typename A> +void count_min_sketch<W,A>::check_header_validity(uint8_t preamble_longs, uint8_t serial_version, uint8_t family_id, uint8_t flags_byte) { const bool empty = (flags_byte & (1 << flags::IS_EMPTY)) > 0; const uint8_t sw = (empty ? 1 : 0) + (2 * serial_version) + (4 * family_id) + (32 * (preamble_longs & 0x3F)); bool valid = true; switch (sw) { // exhaustive list and description of all valid cases + case 70 : break; // !empty, ser_ver==1, family==1, preLongs=2; case 71 : break; // empty, ser_ver==1, family==1, preLongs=2; - case 102 : break ; // !empty, ser_ver==1, family==1, preLongs=3 ; + //case 102 : break ; // !empty, ser_ver==1, family==1, preLongs=3 ; default : // all other case values are invalid valid = false; } @@ -490,14 +500,13 @@ void count_min_sketch<W>::check_header_validity(uint8_t preamble_longs, uint8_t if (!valid) { std::ostringstream os; os << "Possible sketch corruption. Inconsistent state: " - << "preamble_longs = " << preamble_longs + << "preamble_longs = " << static_cast<uint32_t>(preamble_longs) << ", empty = " << (empty ? "true" : "false") - << ", serialization_version = " << serial_version ; + << ", serialization_version = " << static_cast<uint32_t>(serial_version) ; throw std::invalid_argument(os.str()); } } - } /* namespace datasketches */ #endif diff --git a/count/test/CMakeLists.txt b/count/test/CMakeLists.txt index 9d07e83..d81c95e 100644 --- a/count/test/CMakeLists.txt +++ b/count/test/CMakeLists.txt @@ -39,4 +39,5 @@ add_test( target_sources(count_min_test PRIVATE count_min_test.cpp + count_min_allocation_test.cpp ) diff --git a/count/test/count_min_allocation_test.cpp b/count/test/count_min_allocation_test.cpp new file mode 100644 index 0000000..68df534 --- /dev/null +++ b/count/test/count_min_allocation_test.cpp @@ -0,0 +1,155 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +#include <catch2/catch.hpp> +#include <vector> +#include <cstring> +#include <sstream> +#include <fstream> + +#include "count_min.hpp" +#include "common_defs.hpp" +#include "test_allocator.hpp" + +namespace datasketches { + +using count_min_sketch_test_alloc = count_min_sketch<uint64_t, test_allocator<uint64_t>>; +using alloc = test_allocator<uint64_t>; + +TEST_CASE("CountMin sketch test allocator: serialize-deserialize empty", "[cm_sketch_alloc]"){ + test_allocator_total_bytes = 0; + test_allocator_net_allocations = 0; + { + uint8_t n_hashes = 1 ; + uint32_t n_buckets = 5 ; + std::stringstream s(std::ios::in | std::ios::out | std::ios::binary); + count_min_sketch_test_alloc c(n_hashes, n_buckets, DEFAULT_SEED, alloc(0)) ; + c.serialize(s); + count_min_sketch_test_alloc d = count_min_sketch_test_alloc::deserialize(s, DEFAULT_SEED, alloc(0)) ; + REQUIRE(c.get_num_hashes() == d.get_num_hashes()) ; + REQUIRE(c.get_num_buckets() == d.get_num_buckets()) ; + REQUIRE(c.get_seed() == d.get_seed()) ; + uint64_t zero = 0; + REQUIRE(c.get_estimate(zero) == d.get_estimate(zero)) ; + REQUIRE(c.get_total_weight() == d.get_total_weight()) ; + + // Check that all entries are equal and 0 + for(auto di: d){ + REQUIRE(di == 0) ; + } + } + REQUIRE(test_allocator_total_bytes == 0); + REQUIRE(test_allocator_net_allocations == 0); +} + +TEST_CASE("CountMin sketch test allocator: serialize-deserialize non-empty", "[cm_sketch_alloc]"){ + test_allocator_total_bytes = 0; + test_allocator_net_allocations = 0; + { + uint8_t n_hashes = 3 ; + uint32_t n_buckets = 1024 ; + std::stringstream s(std::ios::in | std::ios::out | std::ios::binary); + count_min_sketch_test_alloc c(n_hashes, n_buckets, DEFAULT_SEED, alloc(0)) ; + for(uint64_t i=0 ; i < 10; ++i) c.update(i,10*i*i) ; + c.serialize(s); + count_min_sketch_test_alloc d = count_min_sketch_test_alloc::deserialize(s, DEFAULT_SEED, alloc(0)) ; + REQUIRE(c.get_num_hashes() == d.get_num_hashes()) ; + REQUIRE(c.get_num_buckets() == d.get_num_buckets()) ; + REQUIRE(c.get_seed() == d.get_seed()) ; + REQUIRE(c.get_total_weight() == d.get_total_weight()) ; + for(uint64_t i=0 ; i < 10; ++i){ + REQUIRE(c.get_estimate(i) == d.get_estimate(i)) ; + } + + auto c_it = c.begin() ; + auto d_it = d.begin() ; + while(c_it != c.end()){ + REQUIRE(*c_it == *d_it) ; + ++c_it ; + ++d_it ; + } + } + REQUIRE(test_allocator_total_bytes == 0); + REQUIRE(test_allocator_net_allocations == 0); +} + +TEST_CASE("CountMin sketch test allocator: bytes serialize-deserialize empty", "[cm_sketch_alloc]"){ + test_allocator_total_bytes = 0; + test_allocator_net_allocations = 0; + { + uint8_t n_hashes = 3 ; + uint32_t n_buckets = 32 ; + count_min_sketch_test_alloc c(n_hashes, n_buckets, DEFAULT_SEED, alloc(0)) ; + auto bytes = c.serialize() ; + + REQUIRE_THROWS_AS(count_min_sketch_test_alloc::deserialize(bytes.data(), bytes.size(), DEFAULT_SEED-1, alloc(0)), std::invalid_argument); + auto d = count_min_sketch_test_alloc::deserialize(bytes.data(), bytes.size(), DEFAULT_SEED, alloc(0)) ; + REQUIRE(c.get_num_hashes() == d.get_num_hashes()) ; + REQUIRE(c.get_num_buckets() == d.get_num_buckets()) ; + REQUIRE(c.get_seed() == d.get_seed()) ; + uint64_t zero = 0; + REQUIRE(c.get_estimate(zero) == d.get_estimate(zero)) ; + REQUIRE(c.get_total_weight() == d.get_total_weight()) ; + + // Check that all entries are equal and 0 + for(auto di: d){ + REQUIRE(di == 0) ; + } + } + REQUIRE(test_allocator_total_bytes == 0); + REQUIRE(test_allocator_net_allocations == 0); +} + +TEST_CASE("CountMin sketch test allocator: bytes serialize-deserialize non-empty", "[cm_sketch_alloc]"){ + test_allocator_total_bytes = 0; + test_allocator_net_allocations = 0; + { + uint8_t n_hashes = 5 ; + uint32_t n_buckets = 64 ; + count_min_sketch_test_alloc c(n_hashes, n_buckets, DEFAULT_SEED, alloc(0)) ; + for(uint64_t i=0 ; i < 10; ++i) c.update(i,10*i*i) ; + + auto bytes = c.serialize() ; + REQUIRE_THROWS_AS(count_min_sketch_test_alloc::deserialize(bytes.data(), bytes.size(), DEFAULT_SEED-1, alloc(0)), std::invalid_argument); + auto d = count_min_sketch_test_alloc::deserialize(bytes.data(), bytes.size(), DEFAULT_SEED, alloc(0)) ; + + REQUIRE(c.get_num_hashes() == d.get_num_hashes()) ; + REQUIRE(c.get_num_buckets() == d.get_num_buckets()) ; + REQUIRE(c.get_seed() == d.get_seed()) ; + REQUIRE(c.get_total_weight() == d.get_total_weight()) ; + + // Check that all entries are equal + auto c_it = c.begin() ; + auto d_it = d.begin() ; + while(c_it != c.end()){ + REQUIRE(*c_it == *d_it) ; + ++c_it ; + ++d_it ; + } + + // Check that the estimates agree + for(uint64_t i=0 ; i < 10; ++i){ + REQUIRE(c.get_estimate(i) == d.get_estimate(i)) ; + } + } + REQUIRE(test_allocator_total_bytes == 0); + REQUIRE(test_allocator_net_allocations == 0); +} + +} // namespace datasketches diff --git a/count/test/count_min_test.cpp b/count/test/count_min_test.cpp index 4caaa43..2c2afc1 100644 --- a/count/test/count_min_test.cpp +++ b/count/test/count_min_test.cpp @@ -1,3 +1,22 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + #include <catch2/catch.hpp> #include <vector> #include <cstring> @@ -190,7 +209,8 @@ TEST_CASE("CountMin sketch: serialize-deserialize empty", "[cm_sketch]"){ REQUIRE(c.get_num_hashes() == d.get_num_hashes()) ; REQUIRE(c.get_num_buckets() == d.get_num_buckets()) ; REQUIRE(c.get_seed() == d.get_seed()) ; - REQUIRE(c.get_estimate(0) == d.get_estimate(0)) ; + uint64_t zero = 0; + REQUIRE(c.get_estimate(zero) == d.get_estimate(zero)) ; REQUIRE(c.get_total_weight() == d.get_total_weight()) ; // Check that all entries are equal and 0 @@ -240,7 +260,8 @@ TEST_CASE("CountMin sketch: bytes serialize-deserialize empty", "[cm_sketch]"){ REQUIRE(c.get_num_hashes() == d.get_num_hashes()) ; REQUIRE(c.get_num_buckets() == d.get_num_buckets()) ; REQUIRE(c.get_seed() == d.get_seed()) ; - REQUIRE(c.get_estimate(0) == d.get_estimate(0)) ; + uint64_t zero = 0; + REQUIRE(c.get_estimate(zero) == d.get_estimate(zero)) ; REQUIRE(c.get_total_weight() == d.get_total_weight()) ; // Check that all entries are equal and 0 @@ -281,6 +302,5 @@ TEST_CASE("CountMin sketch: bytes serialize-deserialize non-empty", "[cm_sketch] } - } /* namespace datasketches */ --------------------------------------------------------------------- To unsubscribe, e-mail: [email protected] For additional commands, e-mail: [email protected]
