This is an automated email from the ASF dual-hosted git repository.

jmalkin pushed a commit to branch vo_exception_safety
in repository https://gitbox.apache.org/repos/asf/incubator-datasketches-cpp.git

commit cd747bf79cf3ba7ace6bb5f872554fff12bcd51d
Author: Jon Malkin <[email protected]>
AuthorDate: Fri May 8 01:21:37 2020 -0700

    add exception handling to varopt deserialize, untested
---
 common/include/memory_operations.hpp     |   6 +-
 sampling/include/var_opt_sketch.hpp      |  17 +-
 sampling/include/var_opt_sketch_impl.hpp | 256 ++++++++++++++++++++++++++++---
 3 files changed, 250 insertions(+), 29 deletions(-)

diff --git a/common/include/memory_operations.hpp b/common/include/memory_operations.hpp
index 6452721..80dc3a3 100644
--- a/common/include/memory_operations.hpp
+++ b/common/include/memory_operations.hpp
@@ -17,8 +17,8 @@
  * under the License.
  */
 
-#ifndef _MEMORY_CHECKS_HPP_
-#define _MEMORY_CHECKS_HPP_
+#ifndef _MEMORY_OPERATIONS_HPP_
+#define _MEMORY_OPERATIONS_HPP_
 
 #include <memory>
 #include <exception>
@@ -54,4 +54,4 @@ static inline size_t copy_to_mem(const void* src, void* dst, size_t size) {
 
 } // namespace
 
-#endif // _MEMORY_CHECKS_HPP_
+#endif // _MEMORY_OPERATIONS_HPP_
diff --git a/sampling/include/var_opt_sketch.hpp b/sampling/include/var_opt_sketch.hpp
index d2f7e9b..1357a3b 100644
--- a/sampling/include/var_opt_sketch.hpp
+++ b/sampling/include/var_opt_sketch.hpp
@@ -266,9 +266,20 @@ class var_opt_sketch {
     // occurs and is properly tracked.
     bool* marks_;
 
+    // used during deserialization to avoid memory leaks upon errors
+    class items_deleter;
+    class weights_deleter;
+    class marks_deleter;
+
     var_opt_sketch(uint32_t k, resize_factor rf, bool is_gadget);
-    var_opt_sketch(uint32_t k, resize_factor rf, bool is_gadget, uint8_t preamble_longs, std::istream& is);
-    var_opt_sketch(uint32_t k, resize_factor rf, bool is_gadget, uint8_t preamble_longs, const void* bytes, size_t size);
+    // used by deserialize(): takes ownership of the already-validated arrays by releasing
+    // items, weights and marks into data_, weights_ and marks_
+    var_opt_sketch(uint32_t k, uint32_t h, uint32_t m, uint32_t r, uint64_t n, double total_wt_r, resize_factor rf,
+                   uint32_t curr_items_alloc, bool filled_data, std::unique_ptr<T, items_deleter> items,
+                   std::unique_ptr<double, weights_deleter> weights, uint32_t num_marks_in_h,
+                   std::unique_ptr<bool, marks_deleter> marks);
+    //var_opt_sketch(uint32_t k, resize_factor rf, bool is_gadget, uint8_t preamble_longs, std::istream& is);
+    //var_opt_sketch(uint32_t k, resize_factor rf, bool is_gadget, uint8_t preamble_longs, const void* bytes, size_t size);
 
     friend class var_opt_union<T,S,A>;
     var_opt_sketch(const var_opt_sketch& other, bool as_sketch, uint64_t adjusted_n);
@@ -307,8 +318,10 @@ class var_opt_sketch {
     // validation
     static void check_preamble_longs(uint8_t preamble_longs, uint8_t flags);
     static void check_family_and_serialization_version(uint8_t family_id, uint8_t ser_ver);
-    // next method sets current_items_alloc_
-    void validate_and_set_current_size(uint32_t preamble_longs);
+    // throws std::invalid_argument if k, n, h, r are inconsistent with preamble_longs;
+    // otherwise returns the array size to allocate for the deserialized sketch
+    static uint32_t validate_and_get_target_size(uint32_t preamble_longs, uint32_t k, uint64_t n,
+                                                 uint32_t h, uint32_t r, resize_factor rf);
 
     // things to move to common and be shared among sketches
     static uint32_t get_adjusted_size(uint32_t max_size, uint32_t resize_target);
diff --git a/sampling/include/var_opt_sketch_impl.hpp b/sampling/include/var_opt_sketch_impl.hpp
index bf08b91..7b56510 100644
--- a/sampling/include/var_opt_sketch_impl.hpp
+++ b/sampling/include/var_opt_sketch_impl.hpp
@@ -175,6 +175,26 @@ var_opt_sketch<T,S,A>::var_opt_sketch(uint32_t k, resize_factor rf, bool is_gadg
 }
 
 template<typename T, typename S, typename A>
+var_opt_sketch<T,S,A>::var_opt_sketch(uint32_t k, uint32_t h, uint32_t m, uint32_t r, uint64_t n, double total_wt_r, resize_factor rf,
+                                      uint32_t curr_items_alloc, bool filled_data, std::unique_ptr<T, items_deleter> items,
+                                      std::unique_ptr<double, weights_deleter> weights, uint32_t num_marks_in_h,
+                                      std::unique_ptr<bool, marks_deleter> marks) :
+  k_(k),
+  h_(h),
+  m_(m),
+  r_(r),
+  n_(n),
+  total_wt_r_(total_wt_r),
+  rf_(rf),
+  curr_items_alloc_(curr_items_alloc),
+  filled_data_(filled_data),
+  data_(items.release()),
+  weights_(weights.release()),
+  num_marks_in_h_(num_marks_in_h),
+  marks_(marks.release())
+{}
+
+template<typename T, typename S, typename A>
 var_opt_sketch<T,S,A>::~var_opt_sketch() {
   if (data_ != nullptr) {
     if (filled_data_) {
@@ -242,6 +263,7 @@ var_opt_sketch<T,S,A>& var_opt_sketch<T,S,A>::operator=(var_opt_sketch&& other)
   return *this;
 }
 
+/* NOTE(review): deserialize() no longer calls the two constructors below; presumably kept for reference while the new path is untested -- delete them once it is.
 template<typename T, typename S, typename A>
 var_opt_sketch<T,S,A>::var_opt_sketch(uint32_t k, resize_factor rf, bool is_gadget, uint8_t preamble_longs, std::istream& is) :
   k_(k), m_(0), rf_(rf) {
@@ -298,6 +320,7 @@ var_opt_sketch<T,S,A>::var_opt_sketch(uint32_t k, resize_factor rf, bool is_gadg
   S().deserialize(is, &data_[h_ + 1], r_);
 }
 
+
 template<typename T, typename S, typename A>
 var_opt_sketch<T,S,A>::var_opt_sketch(uint32_t k, resize_factor rf, bool is_gadget, uint8_t preamble_longs,
                                       const void* bytes, size_t size) : k_(k), m_(0), rf_(rf) {
@@ -360,6 +383,7 @@ var_opt_sketch<T,S,A>::var_opt_sketch(uint32_t k, resize_factor rf, bool is_gadg
   ptr += S().deserialize(ptr, end_ptr - ptr, data_, h_);
   ptr += S().deserialize(ptr, end_ptr - ptr, &data_[h_ + 1], r_);
 }
+*/
 
 /*
  * An empty sketch requires 8 bytes.
@@ -569,6 +593,8 @@ template<typename T, typename S, typename A>
 var_opt_sketch<T,S,A> var_opt_sketch<T,S,A>::deserialize(const void* bytes, size_t size) {
   ensure_minimum_memory(size, 8);
   const char* ptr = static_cast<const char*>(bytes);
+  const char* base = ptr;
+  const char* end_ptr = ptr + size;
   uint8_t first_byte;
   ptr += copy_from_mem(ptr, &first_byte, sizeof(first_byte));
   uint8_t preamble_longs = first_byte & 0x3f;
@@ -589,7 +615,71 @@ var_opt_sketch<T,S,A> var_opt_sketch<T,S,A>::deserialize(const void* bytes, size
   const bool is_empty = flags & EMPTY_FLAG_MASK;
   const bool is_gadget = flags & GADGET_FLAG_MASK;
 
-  return is_empty ? var_opt_sketch<T,S,A>(k, rf, is_gadget) : var_opt_sketch<T,S,A>(k, rf, is_gadget, preamble_longs, bytes, size);
+  if (is_empty) {
+    return var_opt_sketch<T,S,A>(k, rf, is_gadget);
+  }
+
+  // second and third prelongs
+  uint64_t n;
+  uint32_t h, r;
+  ptr += copy_from_mem(ptr, &n, sizeof(n));
+  ptr += copy_from_mem(ptr, &h, sizeof(h));
+  ptr += copy_from_mem(ptr, &r, sizeof(r));
+
+  const uint32_t array_size = validate_and_get_target_size(preamble_longs, k, n, h, r, rf);
+
+  // array_size is now known; validate the R region weight (4th prelong), if present, before allocating
+  double total_wt_r = 0.0;
+  if (preamble_longs == PREAMBLE_LONGS_FULL) {
+    ptr += copy_from_mem(ptr, &total_wt_r, sizeof(total_wt_r));
+    if (std::isnan(total_wt_r) || r == 0 || total_wt_r <= 0.0) {
+      throw std::invalid_argument("Possible corruption: deserializing in full mode but r = 0 or invalid R weight. "
+       "Found r = " + std::to_string(r) + ", R region weight = " + std::to_string(total_wt_r));
+    }
+  }
+
+  // read the first h weights, fill in rest of array with -1.0
+  check_memory_size(ptr - base + (h * sizeof(double)), size);
+  std::unique_ptr<double, weights_deleter> weights(AllocDouble().allocate(array_size), weights_deleter(array_size));
+  double* wts = weights.get(); // to avoid lots of .get() calls -- do not delete
+  ptr += copy_from_mem(ptr, wts, h * sizeof(double));
+  for (size_t i = 0; i < h; ++i) {
+    if (!(wts[i] > 0.0)) {
+      throw std::invalid_argument("Possible corruption: Non-positive weight when deserializing: " + std::to_string(wts[i]));
+    }
+  }
+  std::fill(wts + h, wts + array_size, -1.0);
+
+  // read the first h marks as packed bytes iff we have a gadget
+  uint32_t num_marks_in_h = 0;
+  std::unique_ptr<bool, marks_deleter> marks(nullptr, marks_deleter(array_size));
+  if (is_gadget) {
+    uint8_t val = 0;
+    marks = std::unique_ptr<bool, marks_deleter>(AllocBool().allocate(array_size), marks_deleter(array_size));
+    const size_t size_marks = (h / 8) + (h % 8 > 0 ? 1 : 0);
+    check_memory_size(ptr - base + size_marks, size);
+    for (uint32_t i = 0; i < h; ++i) {
+      if ((i & 0x7) == 0x0) { // should trigger on first iteration
+        ptr += copy_from_mem(ptr, &val, sizeof(val));
+      }
+      marks.get()[i] = ((val >> (i & 0x7)) & 0x1) == 1;
+      num_marks_in_h += (marks.get()[i] ? 1 : 0);
+    }
+  }
+
+  // read the sample items, skipping the gap. Either h or r may be 0.
+  // The unique_ptr owns its own copy of the deleter, so constructed-item counts must be recorded
+  // via get_deleter(); updating a separate local deleter would leak the items on a later throw.
+  std::unique_ptr<T, items_deleter> items(A().allocate(array_size), items_deleter(array_size));
+
+  ptr += S().deserialize(ptr, end_ptr - ptr, items.get(), h);
+  items.get_deleter().set_h(h); // serde didn't throw, so the items are now valid
+
+  ptr += S().deserialize(ptr, end_ptr - ptr, &(items.get()[h + 1]), r);
+  items.get_deleter().set_r(r); // serde didn't throw, so the items are now valid
+
+  return var_opt_sketch(k, h, (r > 0 ? 1 : 0), r, n, total_wt_r, rf, array_size, false,
+                        std::move(items), std::move(weights), num_marks_in_h, std::move(marks));
 }
 
 template<typename T, typename S, typename A>
@@ -613,7 +704,71 @@ var_opt_sketch<T,S,A> var_opt_sketch<T,S,A>::deserialize(std::istream& is) {
   const bool is_empty = flags & EMPTY_FLAG_MASK;
   const bool is_gadget = flags & GADGET_FLAG_MASK;
 
-  return is_empty ? var_opt_sketch<T,S,A>(k, rf, is_gadget) : var_opt_sketch<T,S,A>(k, rf, is_gadget, preamble_longs, is);
+  if (is_empty) {
+    return var_opt_sketch<T,S,A>(k, rf, is_gadget);
+  }
+
+  // second and third prelongs
+  uint64_t n;
+  uint32_t h, r;
+  is.read((char*)&n, sizeof(n));
+  is.read((char*)&h, sizeof(h));
+  is.read((char*)&r, sizeof(r));
+  if (!is.good()) throw std::runtime_error("error reading from std::istream"); // n, h, r would be indeterminate
+
+  const uint32_t array_size = validate_and_get_target_size(preamble_longs, k, n, h, r, rf);
+
+  // array_size is now known; validate the R region weight (4th prelong), if present, before allocating
+  double total_wt_r = 0.0;
+  if (preamble_longs == PREAMBLE_LONGS_FULL) {
+    is.read((char*)&total_wt_r, sizeof(total_wt_r));
+    if (std::isnan(total_wt_r) || r == 0 || total_wt_r <= 0.0) {
+      throw std::invalid_argument("Possible corruption: deserializing in full mode but r = 0 or invalid R weight. "
+       "Found r = " + std::to_string(r) + ", R region weight = " + std::to_string(total_wt_r));
+    }
+  }
+
+  // read the first h weights, fill remainder with -1.0
+  std::unique_ptr<double, weights_deleter> weights(AllocDouble().allocate(array_size), weights_deleter(array_size));
+  double* wts = weights.get(); // to avoid lots of .get() calls -- do not delete
+  is.read((char*)wts, h * sizeof(double));
+  if (!is.good()) throw std::runtime_error("error reading from std::istream"); // don't inspect unread weights
+  for (size_t i = 0; i < h; ++i) {
+    if (!(wts[i] > 0.0)) {
+      throw std::invalid_argument("Possible corruption: Non-positive weight when deserializing: " + std::to_string(wts[i]));
+    }
+  }
+  std::fill(wts + h, wts + array_size, -1.0);
+
+  // read the first h marks as packed bytes iff we have a gadget
+  uint32_t num_marks_in_h = 0;
+  std::unique_ptr<bool, marks_deleter> marks(nullptr, marks_deleter(array_size));
+  if (is_gadget) {
+    marks = std::unique_ptr<bool, marks_deleter>(AllocBool().allocate(array_size), marks_deleter(array_size));
+    uint8_t val = 0;
+    for (uint32_t i = 0; i < h; ++i) {
+      if ((i & 0x7) == 0x0) { // should trigger on first iteration
+        is.read((char*)&val, sizeof(val));
+      }
+      marks.get()[i] = ((val >> (i & 0x7)) & 0x1) == 1;
+      num_marks_in_h += (marks.get()[i] ? 1 : 0);
+    }
+  }
+
+  // read the sample items, skipping the gap. Either h or r may be 0.
+  // The unique_ptr owns its own copy of the deleter, so constructed-item counts must be recorded
+  // via get_deleter(); updating a separate local deleter would leak the items on a later throw.
+  std::unique_ptr<T, items_deleter> items(A().allocate(array_size), items_deleter(array_size));
+
+  S().deserialize(is, items.get(), h); // aka &data_[0]
+  items.get_deleter().set_h(h); // serde didn't throw, so the items are now valid
+
+  S().deserialize(is, &(items.get()[h + 1]), r);
+  items.get_deleter().set_r(r); // serde didn't throw, so the items are now valid
+  if (!is.good()) throw std::runtime_error("error reading from std::istream");
+
+  return var_opt_sketch(k, h, (r > 0 ? 1 : 0), r, n, total_wt_r, rf, array_size, false,
+                        std::move(items), std::move(weights), num_marks_in_h, std::move(marks));
 }
 
 template<typename T, typename S, typename A>
@@ -1299,46 +1452,50 @@ void var_opt_sketch<T,S,A>::check_family_and_serialization_version(uint8_t famil
 }
 
 template<typename T, typename S, typename A>
-void var_opt_sketch<T, S, A>::validate_and_set_current_size(uint32_t preamble_longs) {
-  if (k_ == 0 || k_ > MAX_K) {
+uint32_t var_opt_sketch<T, S, A>::validate_and_get_target_size(uint32_t preamble_longs, uint32_t k, uint64_t n,
+                                                               uint32_t h, uint32_t r, resize_factor rf) {
+  if (k == 0 || k > MAX_K) {
     throw std::invalid_argument("k must be at least 1 and less than 2^31 - 1");
   }
 
-  if (n_ <= k_) {
+  uint32_t array_size;
+
+  if (n <= k) {
     if (preamble_longs != PREAMBLE_LONGS_WARMUP) {
      throw std::invalid_argument("Possible corruption: deserializing with n <= k but not in warmup mode. "
-       "Found n = " + std::to_string(n_) + ", k = " + std::to_string(k_));
+       "Found n = " + std::to_string(n) + ", k = " + std::to_string(k));
     }
-    if (n_ != h_) {
+    if (n != h) {
      throw std::invalid_argument("Possible corruption: deserializing in warmup mode but n != h. "
-       "Found n = " + std::to_string(n_) + ", h = " + std::to_string(h_));
+       "Found n = " + std::to_string(n) + ", h = " + std::to_string(h));
     }
-    if (r_ > 0) {
+    if (r > 0) {
      throw std::invalid_argument("Possible corruption: deserializing in warmup mode but r > 0. "
-       "Found r = " + std::to_string(r_));
+       "Found r = " + std::to_string(r));
     }
 
-    const uint32_t ceiling_lg_k = to_log_2(ceiling_power_of_2(k_));
-    const uint32_t min_lg_size = to_log_2(ceiling_power_of_2(h_));
-    const uint32_t initial_lg_size = starting_sub_multiple(ceiling_lg_k, rf_, min_lg_size);
-    curr_items_alloc_ = get_adjusted_size(k_, 1 << initial_lg_size);
-    if (curr_items_alloc_ == k_) { // if full size, need to leave 1 for the gap
-      ++curr_items_alloc_;
+    const uint32_t ceiling_lg_k = to_log_2(ceiling_power_of_2(k));
+    const uint32_t min_lg_size = to_log_2(ceiling_power_of_2(h));
+    const uint32_t initial_lg_size = starting_sub_multiple(ceiling_lg_k, rf, min_lg_size);
+    array_size = get_adjusted_size(k, 1 << initial_lg_size);
+    if (array_size == k) { // if full size, need to leave 1 for the gap
+      ++array_size;
     }
-  } else { // n_ > k_
+  } else { // n > k
     if (preamble_longs != PREAMBLE_LONGS_FULL) { 
      throw std::invalid_argument("Possible corruption: deserializing with n > k but not in full mode. "
-       "Found n = " + std::to_string(n_) + ", k = " + std::to_string(k_));
+       "Found n = " + std::to_string(n) + ", k = " + std::to_string(k));
     }
-    if (h_ + r_ != k_) {
-      throw std::invalid_argument("Possible corruption: deserializing in full mode but h + r != n. "
-       "Found h = " + std::to_string(h_) + ", r = " + std::to_string(r_) + ", n = " + std::to_string(n_));
+    if (h + r != k) {
+      throw std::invalid_argument("Possible corruption: deserializing in full mode but h + r != k. "
+       "Found h = " + std::to_string(h) + ", r = " + std::to_string(r) + ", k = " + std::to_string(k));
     }
 
-    curr_items_alloc_ = k_ + 1;
+    array_size = k + 1;
   }
-}
 
+  return array_size;
+}
 
 template<typename T, typename S, typename A>
 template<typename P>
@@ -1389,6 +1546,56 @@ subset_summary var_opt_sketch<T, S, A>::estimate_subset_sum(P predicate) const {
 }
 
 template<typename T, typename S, typename A>
+class var_opt_sketch<T, S, A>::items_deleter {
+  public:
+  explicit items_deleter(uint32_t num) : num(num), h_count(0), r_count(0) {}
+  void set_h(uint32_t h) { h_count = h; }
+  void set_r(uint32_t r) { r_count = r; }
+  void operator() (T* ptr) const {
+    if (ptr == nullptr) return;
+    // destroy only the items recorded as constructed: [0, h_count) and [h_count + 1, h_count + r_count]
+    for (uint32_t i = 0; i < h_count; ++i) {
+      ptr[i].~T();
+    }
+    for (uint32_t i = h_count + 1; i < h_count + r_count + 1; ++i) {
+      ptr[i].~T();
+    }
+    A().deallocate(ptr, num);
+  }
+  private:
+  uint32_t num;
+  uint32_t h_count;
+  uint32_t r_count;
+};
+
+template<typename T, typename S, typename A>
+class var_opt_sketch<T, S, A>::weights_deleter {
+  public:
+  explicit weights_deleter(uint32_t num) : num(num) {}
+  void operator() (double* ptr) const {
+    if (ptr != nullptr) {
+      AllocDouble().deallocate(ptr, num);
+    }
+  }
+  private:
+  uint32_t num;
+};
+
+template<typename T, typename S, typename A>
+class var_opt_sketch<T, S, A>::marks_deleter {
+  public:
+  explicit marks_deleter(uint32_t num) : num(num) {}
+  void operator() (bool* ptr) const {
+    if (ptr != nullptr) {
+      AllocBool().deallocate(ptr, num); // must match the count passed to allocate()
+    }
+  }
+  private:
+  uint32_t num;
+};
+
+
+template<typename T, typename S, typename A>
 typename var_opt_sketch<T, S, A>::const_iterator var_opt_sketch<T, S, A>::begin() const {
   return var_opt_sketch<T, S, A>::const_iterator(*this, false);
 }


---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]

Reply via email to