This is an automated email from the ASF dual-hosted git repository.

jmalkin pushed a commit to branch count_min_python
in repository https://gitbox.apache.org/repos/asf/datasketches-cpp.git

commit c8cf06f434bd20bacf149f7793a8e4b7df6f45f0
Author: Jon Malkin <[email protected]>
AuthorDate: Thu Apr 6 00:17:33 2023 -0700

    Add apache license header to files. Add allocator support, clean up 
serialized size
---
 count/include/count_min.hpp              |  58 ++++++--
 count/include/count_min_impl.hpp         | 223 ++++++++++++++++---------------
 count/test/CMakeLists.txt                |   1 +
 count/test/count_min_allocation_test.cpp | 155 +++++++++++++++++++++
 count/test/count_min_test.cpp            |  26 +++-
 5 files changed, 344 insertions(+), 119 deletions(-)

diff --git a/count/include/count_min.hpp b/count/include/count_min.hpp
index c4bd752..c54fb25 100644
--- a/count/include/count_min.hpp
+++ b/count/include/count_min.hpp
@@ -1,3 +1,22 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
 #ifndef COUNT_MIN_HPP_
 #define COUNT_MIN_HPP_
 
@@ -14,12 +33,13 @@ namespace datasketches {
    * @author Charlie Dickens
    */
 
-template<typename W>
+template <typename W,
+          typename Allocator = std::allocator<W>>
 class count_min_sketch{
   static_assert(std::is_arithmetic<W>::value, "Arithmetic type expected");
 public:
+  using allocator_type = Allocator;
 
-  using vector_bytes = std::vector<uint8_t>;
   /**
    * Creates an instance of the sketch given parameters _num_hashes, 
_num_buckets and hash seed, `seed`.
    * @param num_hashes : number of hash functions in the sketch. Equivalently 
the number of rows in the array
@@ -29,7 +49,7 @@ public:
    * The items inserted into the sketch can be arbitrary type, so long as they 
are hashable via murmurhash.
    * Only update and estimate methods are added for uint64_t and string types.
    */
-  count_min_sketch(uint8_t num_hashes, uint32_t num_buckets, uint64_t seed = 
DEFAULT_SEED) ;
+  count_min_sketch(uint8_t num_hashes, uint32_t num_buckets, uint64_t seed = 
DEFAULT_SEED, const Allocator& allocator = Allocator()) ;
 
   /**
    * @return configured _num_hashes of this sketch
@@ -169,7 +189,7 @@ public:
   /*
    * merges a separate count_min_sketch into this count_min_sketch.
    */
-  void merge(const count_min_sketch<W> &other_sketch) ;
+  void merge(const count_min_sketch &other_sketch) ;
 
   /**
    * Returns true if this sketch is empty.
@@ -183,7 +203,7 @@ public:
    * @brief Returns a string describing the sketch
    * @return A string with a human-readable description of the sketch
    */
-  string<std::allocator<W>> to_string() const;
+  string<Allocator> to_string() const;
 
   // Iterators
   using const_iterator = typename std::vector<W>::const_iterator ;
@@ -238,8 +258,22 @@ public:
    *
    */
 
+  
+  /**
+   * Computes size needed to serialize the current state of the sketch.
+   * @return size in bytes needed to serialize this sketch
+   */
+  size_t get_serialized_size_bytes() const;
+
+  /**
+   * This method serializes a binary image of the sketch to an output stream.
+   */
   void serialize(std::ostream& os) const;
 
+  // This is a convenience alias for users
+  // The type returned by the following serialize method
+  using vector_bytes = std::vector<uint8_t, typename 
std::allocator_traits<Allocator>::template rebind_alloc<uint8_t>>;
+
   /**
    * This method serializes the sketch as a vector of bytes.
    * An optional header can be reserved in front of the sketch.
@@ -255,8 +289,7 @@ public:
   * @param seed the seed for the hash function that was used to create the 
sketch
   * @return an instance of a sketch
   */
-  //static count_min_sketch deserialize(std::istream& is, uint64_t 
seed=DEFAULT_SEED) const;
-  static count_min_sketch deserialize(std::istream& is, uint64_t seed) ;
+  static count_min_sketch deserialize(std::istream& is, uint64_t 
seed=DEFAULT_SEED, const Allocator& allocator = Allocator());
 
   /**
   * This method deserializes a sketch from a given array of bytes.
@@ -265,12 +298,19 @@ public:
   * @param seed the seed for the hash function that was used to create the 
sketch
   * @return an instance of the sketch
   */
-  static count_min_sketch deserialize(const void* bytes, size_t size, uint64_t 
seed=DEFAULT_SEED);
+  static count_min_sketch deserialize(const void* bytes, size_t size, uint64_t 
seed=DEFAULT_SEED, const Allocator& allocator = Allocator());
+
+  /**
+   * Returns the allocator for this sketch.
+   * @return allocator
+   */
+  allocator_type get_allocator() const;
 
 private:
+  Allocator _allocator;
   uint8_t _num_hashes ;
   uint32_t _num_buckets ;
-  std::vector<W> _sketch_array ; // the array stored by the sketch
+  std::vector<W, Allocator> _sketch_array ; // the array stored by the sketch
   uint64_t _seed ;
   W _total_weight ;
   std::vector<uint64_t> hash_seeds ;
diff --git a/count/include/count_min_impl.hpp b/count/include/count_min_impl.hpp
index 8a1551f..7f414a5 100644
--- a/count/include/count_min_impl.hpp
+++ b/count/include/count_min_impl.hpp
@@ -1,3 +1,22 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
 #ifndef COUNT_MIN_IMPL_HPP_
 #define COUNT_MIN_IMPL_HPP_
 
@@ -11,11 +30,12 @@
 
 namespace datasketches {
 
-template<typename W>
-count_min_sketch<W>::count_min_sketch(uint8_t num_hashes, uint32_t 
num_buckets, uint64_t seed):
+template<typename W, typename A>
+count_min_sketch<W,A>::count_min_sketch(uint8_t num_hashes, uint32_t 
num_buckets, uint64_t seed, const A& allocator):
+_allocator(allocator),
 _num_hashes(num_hashes),
 _num_buckets(num_buckets),
-_sketch_array((num_hashes*num_buckets < 1<<30) ? num_hashes*num_buckets : 0, 
0),
+_sketch_array((num_hashes*num_buckets < 1<<30) ? num_hashes*num_buckets : 0, 
0, _allocator),
 _seed(seed),
 _total_weight(0){
   if(num_buckets < 3) throw std::invalid_argument("Using fewer than 3 buckets 
incurs relative error greater than 1.") ;
@@ -27,7 +47,6 @@ _total_weight(0){
                                 "Try reducing either the number of buckets or 
the number of hash functions.") ;
   }
 
-
   std::default_random_engine rng(_seed);
   std::uniform_int_distribution<uint64_t> extra_hash_seeds(0, 
std::numeric_limits<uint64_t>::max());
   hash_seeds.reserve(num_hashes) ;
@@ -37,33 +56,33 @@ _total_weight(0){
   }
 }
 
-template<typename W>
-uint8_t count_min_sketch<W>::get_num_hashes() const{
+template<typename W, typename A>
+uint8_t count_min_sketch<W,A>::get_num_hashes() const{
     return _num_hashes ;
 }
 
-template<typename W>
-uint32_t count_min_sketch<W>::get_num_buckets() const{
+template<typename W, typename A>
+uint32_t count_min_sketch<W,A>::get_num_buckets() const{
     return _num_buckets ;
 }
 
-template<typename W>
-uint64_t count_min_sketch<W>::get_seed() const {
+template<typename W, typename A>
+uint64_t count_min_sketch<W,A>::get_seed() const {
     return _seed ;
 }
 
-template<typename W>
-double count_min_sketch<W>::get_relative_error() const{
+template<typename W, typename A>
+double count_min_sketch<W,A>::get_relative_error() const{
   return exp(1.0) / double(_num_buckets) ;
 }
 
-template<typename W>
-W count_min_sketch<W>::get_total_weight() const{
+template<typename W, typename A>
+W count_min_sketch<W,A>::get_total_weight() const{
   return _total_weight ;
 }
 
-template<typename W>
-uint32_t count_min_sketch<W>::suggest_num_buckets(double relative_error){
+template<typename W, typename A>
+uint32_t count_min_sketch<W,A>::suggest_num_buckets(double relative_error){
   /*
    * Function to help users select a number of buckets for a given error.
    * TODO: Change this when we use only power of 2 buckets.
@@ -75,8 +94,8 @@ uint32_t count_min_sketch<W>::suggest_num_buckets(double 
relative_error){
   return ceil(exp(1.0) / relative_error);
 }
 
-template<typename W>
-uint8_t count_min_sketch<W>::suggest_num_hashes(double confidence){
+template<typename W, typename A>
+uint8_t count_min_sketch<W,A>::suggest_num_hashes(double confidence){
   /*
    * Function to help users select a number of hashes for a given confidence
    * e.g. confidence = 1 - failure probability
@@ -88,8 +107,8 @@ uint8_t count_min_sketch<W>::suggest_num_hashes(double 
confidence){
   return std::min<uint8_t>( ceil(log(1.0/(1.0 - confidence))), UINT8_MAX) ;
 }
 
-template<typename W>
-std::vector<uint64_t> count_min_sketch<W>::get_hashes(const void* item, size_t 
size) const{
+template<typename W, typename A>
+std::vector<uint64_t> count_min_sketch<W,A>::get_hashes(const void* item, 
size_t size) const{
   /*
    * Returns the hash locations for the input item using the original hashing
    * scheme from [1].
@@ -120,20 +139,20 @@ std::vector<uint64_t> 
count_min_sketch<W>::get_hashes(const void* item, size_t s
   return sketch_update_locations ;
 }
 
-template<typename W>
-W count_min_sketch<W>::get_estimate(uint64_t item) const {return 
get_estimate(&item, sizeof(item));}
+template<typename W, typename A>
+W count_min_sketch<W,A>::get_estimate(uint64_t item) const {return 
get_estimate(&item, sizeof(item));}
 
-template<typename W>
-W count_min_sketch<W>::get_estimate(int64_t item) const {return 
get_estimate(&item, sizeof(item));}
+template<typename W, typename A>
+W count_min_sketch<W,A>::get_estimate(int64_t item) const {return 
get_estimate(&item, sizeof(item));}
 
-template<typename W>
-W count_min_sketch<W>::get_estimate(const std::string& item) const {
+template<typename W, typename A>
+W count_min_sketch<W,A>::get_estimate(const std::string& item) const {
   if (item.empty()) return 0 ; // Empty strings are not inserted into the 
sketch.
   return get_estimate(item.c_str(), item.length());
 }
 
-template<typename W>
-W count_min_sketch<W>::get_estimate(const void* item, size_t size) const {
+template<typename W, typename A>
+W count_min_sketch<W,A>::get_estimate(const void* item, size_t size) const {
   /*
    * Returns the estimated frequency of the item
    */
@@ -146,40 +165,40 @@ W count_min_sketch<W>::get_estimate(const void* item, 
size_t size) const {
   return result ;
 }
 
-template<typename W>
-void count_min_sketch<W>::update(uint64_t item, W weight) {
+template<typename W, typename A>
+void count_min_sketch<W,A>::update(uint64_t item, W weight) {
   update(&item, sizeof(item), weight);
 }
 
-template<typename W>
-void count_min_sketch<W>::update(uint64_t item) {
+template<typename W, typename A>
+void count_min_sketch<W,A>::update(uint64_t item) {
   update(&item, sizeof(item), 1);
 }
 
-template<typename W>
-void count_min_sketch<W>::update(int64_t item, W weight) {
+template<typename W, typename A>
+void count_min_sketch<W,A>::update(int64_t item, W weight) {
   update(&item, sizeof(item), weight);
 }
 
-template<typename W>
-void count_min_sketch<W>::update(int64_t item) {
+template<typename W, typename A>
+void count_min_sketch<W,A>::update(int64_t item) {
   update(&item, sizeof(item), 1);
 }
 
-template<typename W>
-void count_min_sketch<W>::update(const std::string& item, W weight) {
+template<typename W, typename A>
+void count_min_sketch<W,A>::update(const std::string& item, W weight) {
   if (item.empty()) return;
   update(item.c_str(), item.length(), weight);
 }
 
-template<typename W>
-void count_min_sketch<W>::update(const std::string& item) {
+template<typename W, typename A>
+void count_min_sketch<W,A>::update(const std::string& item) {
   if (item.empty()) return;
   update(item.c_str(), item.length(), 1);
 }
 
-template<typename W>
-void count_min_sketch<W>::update(const void* item, size_t size, W weight) {
+template<typename W, typename A>
+void count_min_sketch<W,A>::update(const void* item, size_t size, W weight) {
   /*
    * Gets the item's hash locations and then increments the sketch in those
    * locations by the weight.
@@ -192,42 +211,42 @@ void count_min_sketch<W>::update(const void* item, size_t 
size, W weight) {
   }
 }
 
-template<typename W>
-W count_min_sketch<W>::get_upper_bound(uint64_t item) const {return 
get_upper_bound(&item, sizeof(item));}
+template<typename W, typename A>
+W count_min_sketch<W,A>::get_upper_bound(uint64_t item) const {return 
get_upper_bound(&item, sizeof(item));}
 
-template<typename W>
-W count_min_sketch<W>::get_upper_bound(int64_t item) const {return 
get_upper_bound(&item, sizeof(item));}
+template<typename W, typename A>
+W count_min_sketch<W,A>::get_upper_bound(int64_t item) const {return 
get_upper_bound(&item, sizeof(item));}
 
-template<typename W>
-W count_min_sketch<W>::get_upper_bound(const std::string& item) const {
+template<typename W, typename A>
+W count_min_sketch<W,A>::get_upper_bound(const std::string& item) const {
   if (item.empty()) return 0 ; // Empty strings are not inserted into the 
sketch.
   return get_upper_bound(item.c_str(), item.length());
 }
 
-template<typename W>
-W count_min_sketch<W>::get_upper_bound(const void* item, size_t size) const {
+template<typename W, typename A>
+W count_min_sketch<W,A>::get_upper_bound(const void* item, size_t size) const {
   return get_estimate(item, size) + get_relative_error()*get_total_weight() ;
 }
 
-template<typename W>
-W count_min_sketch<W>::get_lower_bound(uint64_t item) const {return 
get_lower_bound(&item, sizeof(item));}
+template<typename W, typename A>
+W count_min_sketch<W,A>::get_lower_bound(uint64_t item) const {return 
get_lower_bound(&item, sizeof(item));}
 
-template<typename W>
-W count_min_sketch<W>::get_lower_bound(int64_t item) const {return 
get_lower_bound(&item, sizeof(item));}
+template<typename W, typename A>
+W count_min_sketch<W,A>::get_lower_bound(int64_t item) const {return 
get_lower_bound(&item, sizeof(item));}
 
-template<typename W>
-W count_min_sketch<W>::get_lower_bound(const std::string& item) const {
+template<typename W, typename A>
+W count_min_sketch<W,A>::get_lower_bound(const std::string& item) const {
   if (item.empty()) return 0 ; // Empty strings are not inserted into the 
sketch.
   return get_lower_bound(item.c_str(), item.length());
 }
 
-template<typename W>
-W count_min_sketch<W>::get_lower_bound(const void* item, size_t size) const {
+template<typename W, typename A>
+W count_min_sketch<W,A>::get_lower_bound(const void* item, size_t size) const {
   return get_estimate(item, size) ;
 }
 
-template<typename W>
-void count_min_sketch<W>::merge(const count_min_sketch<W> &other_sketch){
+template<typename W, typename A>
+void count_min_sketch<W,A>::merge(const count_min_sketch &other_sketch){
   /*
   * Merges this sketch into other_sketch sketch by elementwise summing of 
buckets
   */
@@ -255,29 +274,21 @@ void count_min_sketch<W>::merge(const count_min_sketch<W> 
&other_sketch){
 }
 
 // Iterators
-template<typename W>
-typename count_min_sketch<W>::const_iterator count_min_sketch<W>::begin() 
const {
+template<typename W, typename A>
+typename count_min_sketch<W,A>::const_iterator count_min_sketch<W,A>::begin() 
const {
   return _sketch_array.begin();
 }
 
-template<typename W>
-typename count_min_sketch<W>::const_iterator count_min_sketch<W>::end() const {
+template<typename W, typename A>
+typename count_min_sketch<W,A>::const_iterator count_min_sketch<W,A>::end() 
const {
 return _sketch_array.end();
 }
 
-template<typename W>
-void count_min_sketch<W>::serialize(std::ostream& os) const {
-  // Variable table bytes is used to determine how many bytes to allocate for 
the sketch table.
-  // We assume that 8 bytes are necessary per entry in the table.
-  // The extra 1 is for the total_weight variable which will be zero iff the 
sketch is empty.
-  // Hence, table_bytes == 0 iff sketch is empty <=> preamble_longs == 1
-  const size_t table_bytes(is_empty() ? 0 : (1 + _num_hashes) * _num_buckets);
-  const size_t size = sizeof(uint64_t) * (2 + table_bytes);
-  vector_bytes bytes(size, 0);
-
+template<typename W, typename A>
+void count_min_sketch<W,A>::serialize(std::ostream& os) const {
   // Long 0
-  // The first 4 (of 8) bytes are either 1 or 0 (denoting empty vs non-empty) 
and the final 4 bytes are unused.
-  const uint8_t preamble_longs = is_empty() ? PREAMBLE_LONGS_SHORT : 
PREAMBLE_LONGS_FULL;
+  //const uint8_t preamble_longs = is_empty() ? PREAMBLE_LONGS_SHORT : 
PREAMBLE_LONGS_FULL;
+  const uint8_t preamble_longs = PREAMBLE_LONGS_SHORT;
   const uint8_t ser_ver = SERIAL_VERSION_1;
   const uint8_t family_id = FAMILY_ID ;
   const uint8_t flags_byte = (is_empty() ? 1 << flags::IS_EMPTY : 0);
@@ -311,8 +322,8 @@ void count_min_sketch<W>::serialize(std::ostream& os) const 
{
   }
 }
 
-template<typename W>
-count_min_sketch<W> count_min_sketch<W>::deserialize(std::istream& is, 
uint64_t seed) {
+template<typename W, typename A>
+auto count_min_sketch<W,A>::deserialize(std::istream& is, uint64_t seed, const 
A& allocator) -> count_min_sketch {
 
   // First 8 bytes are 4 bytes of preamble and 4 unused bytes.
   const auto preamble_longs = read<uint8_t>(is) ;
@@ -333,7 +344,7 @@ count_min_sketch<W> 
count_min_sketch<W>::deserialize(std::istream& is, uint64_t
     throw std::invalid_argument("Incompatible seed hashes: " + 
std::to_string(seed_hash) + ", "
                                 + std::to_string(compute_seed_hash(seed)));
   }
-  count_min_sketch<W> c(nhashes, nbuckets, seed) ;
+  count_min_sketch c(nhashes, nbuckets, seed, allocator) ;
   const bool is_empty = (flags_byte & (1 << flags::IS_EMPTY)) > 0;
   if (is_empty == 1) return c ; // sketch is empty, no need to read further.
 
@@ -345,24 +356,23 @@ count_min_sketch<W> 
count_min_sketch<W>::deserialize(std::istream& is, uint64_t
   return c ;
 }
 
-template<typename W>
-auto count_min_sketch<W>::serialize(unsigned header_size_bytes) const -> 
vector_bytes {
+template<typename W, typename A>
+size_t count_min_sketch<W,A>::get_serialized_size_bytes() const {
+  // The header is always PREAMBLE_LONGS_SHORT longs (8 bytes each), whether empty or full
+  size_t preamble_longs = PREAMBLE_LONGS_SHORT;
 
-  // The first 4 (of 8) bytes are either 1 or 0 (denoting empty vs non-empty) 
and the final 4 bytes are unused.
-  const uint8_t preamble_longs = is_empty() ? PREAMBLE_LONGS_SHORT : 
PREAMBLE_LONGS_FULL;
+  // If the sketch is empty, we're done. Otherwise, we need the total weight
+  // held by the sketch as well as a data table of size (num_buckets * 
num_hashes)
+  return preamble_longs * sizeof(uint64_t) + (is_empty() ? 0 : sizeof(W) * (1 + _num_buckets * _num_hashes));
+}
 
-  // Variable table bytes is used to determine how many bytes to allocate for 
the sketch table.
-  // We assume that 8 bytes are necessary per entry in the table.
-  // The extra 1 is for the total_weight variable which will be zero iff the 
sketch is empty.
-  // Hence, table_bytes == 0 iff sketch is empty <=> preamble_longs == 1
-  const size_t table_bytes(is_empty() ? 0 : (1 + _num_hashes) * _num_buckets);
-  const size_t size = header_size_bytes + sizeof(uint64_t) * (2 + table_bytes);
-  vector_bytes bytes(size, 0);
+template<typename W, typename A>
+auto count_min_sketch<W,A>::serialize(unsigned header_size_bytes) const -> 
vector_bytes {
+  vector_bytes bytes(header_size_bytes + get_serialized_size_bytes(), 0, 
_allocator);
   uint8_t *ptr = bytes.data() + header_size_bytes;
-  //std::cout<< "Preamble Long: " << preamble_longs << std::endl;
-  //std::cout<< "Writing " << size << " bytes." << std::endl;
 
   // Long 0
+  const uint8_t preamble_longs = PREAMBLE_LONGS_SHORT;
   ptr += copy_to_mem(preamble_longs, ptr);
   const uint8_t ser_ver = SERIAL_VERSION_1;
   ptr += copy_to_mem(ser_ver, ptr);
@@ -398,8 +408,8 @@ auto count_min_sketch<W>::serialize(unsigned 
header_size_bytes) const -> vector_
   return bytes;
 }
 
-template<typename W>
-count_min_sketch<W> count_min_sketch<W>::deserialize(const void* bytes, size_t 
size, uint64_t seed) {
+template<typename W, typename A>
+auto count_min_sketch<W,A>::deserialize(const void* bytes, size_t size, 
uint64_t seed, const A& allocator) -> count_min_sketch {
   const char* ptr = static_cast<const char*>(bytes);
 
   // First 8 bytes are 4 bytes of preamble and 4 unused bytes.
@@ -428,7 +438,7 @@ count_min_sketch<W> count_min_sketch<W>::deserialize(const 
void* bytes, size_t s
     throw std::invalid_argument("Incompatible seed hashes: " + 
std::to_string(seed_hash) + ", "
                                 + std::to_string(compute_seed_hash(seed)));
   }
-  count_min_sketch<W> c(nhashes, nbuckets, seed) ;
+  count_min_sketch c(nhashes, nbuckets, seed, allocator) ;
   const bool is_empty = (flags_byte & (1 << flags::IS_EMPTY)) > 0;
   if (is_empty) return c ; // sketch is empty, no need to read further.
 
@@ -444,13 +454,13 @@ count_min_sketch<W> 
count_min_sketch<W>::deserialize(const void* bytes, size_t s
   return c;
 }
 
-template<typename W>
-bool count_min_sketch<W>::is_empty() const {
+template<typename W, typename A>
+bool count_min_sketch<W,A>::is_empty() const {
   return _total_weight == 0;
 }
 
-template<typename W>
-string<std::allocator<W>> count_min_sketch<W>::to_string() const {
+template<typename W, typename A>
+string<A> count_min_sketch<W,A>::to_string() const {
   // count the number of used entries in the sketch
   uint64_t num_nonzero = 0;
   for (auto entry : _sketch_array) {
@@ -469,20 +479,20 @@ string<std::allocator<W>> 
count_min_sketch<W>::to_string() const {
   os << "   pct filled     : " << std::setprecision(3) << (num_nonzero * 
100.0) / _sketch_array.size() << "%" << std::endl;
   os << "### End sketch summary" << std::endl;
 
-  //return string<A>(os.str().c_str(), allocator_);
-  return string<std::allocator<W>>(os.str().c_str());
+  return string<A>(os.str().c_str(), _allocator);
 }
 
-template<typename W>
-void count_min_sketch<W>::check_header_validity(uint8_t preamble_longs, 
uint8_t serial_version,  uint8_t family_id, uint8_t flags_byte) {
+template<typename W, typename A>
+void count_min_sketch<W,A>::check_header_validity(uint8_t preamble_longs, 
uint8_t serial_version,  uint8_t family_id, uint8_t flags_byte) {
   const bool empty = (flags_byte & (1 << flags::IS_EMPTY)) > 0;
 
   const uint8_t sw = (empty ? 1 : 0) + (2 * serial_version) + (4 * family_id) 
+ (32 * (preamble_longs & 0x3F));
   bool valid = true;
 
   switch (sw) { // exhaustive list and description of all valid cases
+    case 70 : break; // !empty, ser_ver==1, family==1, preLongs=2;
     case 71 : break; // empty, ser_ver==1, family==1, preLongs=2;
-    case 102 : break ; // !empty, ser_ver==1, family==1, preLongs=3 ;
+    //case 102 : break ; // !empty, ser_ver==1, family==1, preLongs=3 ;
     default : // all other case values are invalid
       valid = false;
   }
@@ -490,14 +500,13 @@ void count_min_sketch<W>::check_header_validity(uint8_t 
preamble_longs, uint8_t
   if (!valid) {
     std::ostringstream os;
     os << "Possible sketch corruption. Inconsistent state: "
-       << "preamble_longs = " << preamble_longs
+       << "preamble_longs = " << static_cast<uint32_t>(preamble_longs)
        << ", empty = " << (empty ? "true" : "false")
-       << ", serialization_version = " << serial_version ;
+       << ", serialization_version = " << 
static_cast<uint32_t>(serial_version) ;
     throw std::invalid_argument(os.str());
   }
 }
 
-
 } /* namespace datasketches */
 
 #endif
diff --git a/count/test/CMakeLists.txt b/count/test/CMakeLists.txt
index 9d07e83..d81c95e 100644
--- a/count/test/CMakeLists.txt
+++ b/count/test/CMakeLists.txt
@@ -39,4 +39,5 @@ add_test(
 target_sources(count_min_test
         PRIVATE
         count_min_test.cpp
+        count_min_allocation_test.cpp
         )
diff --git a/count/test/count_min_allocation_test.cpp 
b/count/test/count_min_allocation_test.cpp
new file mode 100644
index 0000000..68df534
--- /dev/null
+++ b/count/test/count_min_allocation_test.cpp
@@ -0,0 +1,155 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+#include <catch2/catch.hpp>
+#include <vector>
+#include <cstring>
+#include <sstream>
+#include <fstream>
+
+#include "count_min.hpp"
+#include "common_defs.hpp"
+#include "test_allocator.hpp"
+
+namespace datasketches {
+
+using count_min_sketch_test_alloc = count_min_sketch<uint64_t, 
test_allocator<uint64_t>>;
+using alloc = test_allocator<uint64_t>;
+
+TEST_CASE("CountMin sketch test allocator: serialize-deserialize empty", 
"[cm_sketch_alloc]"){
+  test_allocator_total_bytes = 0;
+  test_allocator_net_allocations = 0;
+  {
+    uint8_t n_hashes = 1 ;
+    uint32_t n_buckets = 5 ;
+    std::stringstream s(std::ios::in | std::ios::out | std::ios::binary);
+    count_min_sketch_test_alloc c(n_hashes, n_buckets, DEFAULT_SEED, alloc(0)) 
;
+    c.serialize(s);
+    count_min_sketch_test_alloc d = 
count_min_sketch_test_alloc::deserialize(s, DEFAULT_SEED, alloc(0)) ;
+    REQUIRE(c.get_num_hashes() == d.get_num_hashes()) ;
+    REQUIRE(c.get_num_buckets() == d.get_num_buckets()) ;
+    REQUIRE(c.get_seed() == d.get_seed()) ;
+    uint64_t zero = 0;
+    REQUIRE(c.get_estimate(zero) == d.get_estimate(zero)) ;
+    REQUIRE(c.get_total_weight() == d.get_total_weight()) ;
+
+    // Check that all entries are equal and 0
+    for(auto di: d){
+      REQUIRE(di == 0) ;
+    }
+  }
+  REQUIRE(test_allocator_total_bytes == 0);
+  REQUIRE(test_allocator_net_allocations == 0);
+}
+
+TEST_CASE("CountMin sketch test allocator: serialize-deserialize non-empty", 
"[cm_sketch_alloc]"){
+  test_allocator_total_bytes = 0;
+  test_allocator_net_allocations = 0;
+  {
+    uint8_t n_hashes = 3 ;
+    uint32_t n_buckets = 1024 ;
+    std::stringstream s(std::ios::in | std::ios::out | std::ios::binary);
+    count_min_sketch_test_alloc c(n_hashes, n_buckets, DEFAULT_SEED, alloc(0)) 
;
+    for(uint64_t i=0 ; i < 10; ++i) c.update(i,10*i*i) ;
+    c.serialize(s);
+    count_min_sketch_test_alloc d = 
count_min_sketch_test_alloc::deserialize(s, DEFAULT_SEED, alloc(0)) ;
+    REQUIRE(c.get_num_hashes() == d.get_num_hashes()) ;
+    REQUIRE(c.get_num_buckets() == d.get_num_buckets()) ;
+    REQUIRE(c.get_seed() == d.get_seed()) ;
+    REQUIRE(c.get_total_weight() == d.get_total_weight()) ;
+    for(uint64_t i=0 ; i < 10; ++i){
+      REQUIRE(c.get_estimate(i) == d.get_estimate(i)) ;
+    }
+
+    auto c_it = c.begin() ;
+    auto d_it = d.begin() ;
+    while(c_it != c.end()){
+      REQUIRE(*c_it == *d_it) ;
+      ++c_it ;
+      ++d_it ;
+    }
+  }
+  REQUIRE(test_allocator_total_bytes == 0);
+  REQUIRE(test_allocator_net_allocations == 0);
+}
+
+TEST_CASE("CountMin sketch test allocator: bytes serialize-deserialize empty", 
"[cm_sketch_alloc]"){
+  test_allocator_total_bytes = 0;
+  test_allocator_net_allocations = 0;
+  {
+    uint8_t n_hashes = 3 ;
+    uint32_t n_buckets = 32 ;
+    count_min_sketch_test_alloc c(n_hashes, n_buckets, DEFAULT_SEED, alloc(0)) 
;
+    auto bytes = c.serialize() ;
+
+    REQUIRE_THROWS_AS(count_min_sketch_test_alloc::deserialize(bytes.data(), 
bytes.size(), DEFAULT_SEED-1, alloc(0)), std::invalid_argument);
+    auto d = count_min_sketch_test_alloc::deserialize(bytes.data(), 
bytes.size(), DEFAULT_SEED, alloc(0)) ;
+    REQUIRE(c.get_num_hashes() == d.get_num_hashes()) ;
+    REQUIRE(c.get_num_buckets() == d.get_num_buckets()) ;
+    REQUIRE(c.get_seed() == d.get_seed()) ;
+    uint64_t zero = 0;
+    REQUIRE(c.get_estimate(zero) == d.get_estimate(zero)) ;
+    REQUIRE(c.get_total_weight() == d.get_total_weight()) ;
+
+    // Check that all entries are equal and 0
+    for(auto di: d){
+      REQUIRE(di == 0) ;
+    }
+  }
+  REQUIRE(test_allocator_total_bytes == 0);
+  REQUIRE(test_allocator_net_allocations == 0);
+}
+
+TEST_CASE("CountMin sketch test allocator: bytes serialize-deserialize 
non-empty", "[cm_sketch_alloc]"){
+  test_allocator_total_bytes = 0;
+  test_allocator_net_allocations = 0;
+  {
+    uint8_t n_hashes = 5 ;
+    uint32_t n_buckets = 64 ;
+    count_min_sketch_test_alloc c(n_hashes, n_buckets, DEFAULT_SEED, alloc(0)) 
;
+    for(uint64_t i=0 ; i < 10; ++i) c.update(i,10*i*i) ;
+
+    auto bytes = c.serialize() ;
+    REQUIRE_THROWS_AS(count_min_sketch_test_alloc::deserialize(bytes.data(), 
bytes.size(), DEFAULT_SEED-1, alloc(0)), std::invalid_argument);
+    auto d = count_min_sketch_test_alloc::deserialize(bytes.data(), 
bytes.size(), DEFAULT_SEED, alloc(0)) ;
+
+    REQUIRE(c.get_num_hashes() == d.get_num_hashes()) ;
+    REQUIRE(c.get_num_buckets() == d.get_num_buckets()) ;
+    REQUIRE(c.get_seed() == d.get_seed()) ;
+    REQUIRE(c.get_total_weight() == d.get_total_weight()) ;
+
+    // Check that all entries are equal
+    auto c_it = c.begin() ;
+    auto d_it = d.begin() ;
+    while(c_it != c.end()){
+      REQUIRE(*c_it == *d_it) ;
+      ++c_it ;
+      ++d_it ;
+    }
+
+    // Check that the estimates agree
+    for(uint64_t i=0 ; i < 10; ++i){
+      REQUIRE(c.get_estimate(i) == d.get_estimate(i)) ;
+    }
+  }
+  REQUIRE(test_allocator_total_bytes == 0);
+  REQUIRE(test_allocator_net_allocations == 0);
+}
+
+} // namespace datasketches
diff --git a/count/test/count_min_test.cpp b/count/test/count_min_test.cpp
index 4caaa43..2c2afc1 100644
--- a/count/test/count_min_test.cpp
+++ b/count/test/count_min_test.cpp
@@ -1,3 +1,22 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
 #include <catch2/catch.hpp>
 #include <vector>
 #include <cstring>
@@ -190,7 +209,8 @@ TEST_CASE("CountMin sketch: serialize-deserialize empty", 
"[cm_sketch]"){
     REQUIRE(c.get_num_hashes() == d.get_num_hashes()) ;
     REQUIRE(c.get_num_buckets() == d.get_num_buckets()) ;
     REQUIRE(c.get_seed() == d.get_seed()) ;
-    REQUIRE(c.get_estimate(0) == d.get_estimate(0)) ;
+    uint64_t zero = 0;
+    REQUIRE(c.get_estimate(zero) == d.get_estimate(zero)) ;
     REQUIRE(c.get_total_weight() == d.get_total_weight()) ;
 
     // Check that all entries are equal and 0
@@ -240,7 +260,8 @@ TEST_CASE("CountMin sketch: bytes serialize-deserialize 
empty", "[cm_sketch]"){
   REQUIRE(c.get_num_hashes() == d.get_num_hashes()) ;
   REQUIRE(c.get_num_buckets() == d.get_num_buckets()) ;
   REQUIRE(c.get_seed() == d.get_seed()) ;
-  REQUIRE(c.get_estimate(0) == d.get_estimate(0)) ;
+  uint64_t zero = 0;
+  REQUIRE(c.get_estimate(zero) == d.get_estimate(zero)) ;
   REQUIRE(c.get_total_weight() == d.get_total_weight()) ;
 
   // Check that all entries are equal and 0
@@ -281,6 +302,5 @@ TEST_CASE("CountMin sketch: bytes serialize-deserialize 
non-empty", "[cm_sketch]
 
 }
 
-
 } /* namespace datasketches */
 


---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]

Reply via email to