This is an automated email from the ASF dual-hosted git repository. jmalkin pushed a commit to branch vo_exception_safety in repository https://gitbox.apache.org/repos/asf/incubator-datasketches-cpp.git
commit cd747bf79cf3ba7ace6bb5f872554fff12bcd51d Author: Jon Malkin <[email protected]> AuthorDate: Fri May 8 01:21:37 2020 -0700 add exception handling to varopt deserialize, untested --- common/include/memory_operations.hpp | 6 +- sampling/include/var_opt_sketch.hpp | 17 +- sampling/include/var_opt_sketch_impl.hpp | 256 ++++++++++++++++++++++++++++--- 3 files changed, 250 insertions(+), 29 deletions(-) diff --git a/common/include/memory_operations.hpp b/common/include/memory_operations.hpp index 6452721..80dc3a3 100644 --- a/common/include/memory_operations.hpp +++ b/common/include/memory_operations.hpp @@ -17,8 +17,8 @@ * under the License. */ -#ifndef _MEMORY_CHECKS_HPP_ -#define _MEMORY_CHECKS_HPP_ +#ifndef _MEMORY_OPERATIONS_HPP_ +#define _MEMORY_OPERATIONS_HPP_ #include <memory> #include <exception> @@ -54,4 +54,4 @@ static inline size_t copy_to_mem(const void* src, void* dst, size_t size) { } // namespace -#endif // _MEMORY_CHECKS_HPP_ +#endif // _MEMORY_OPERATIONS_HPP_ diff --git a/sampling/include/var_opt_sketch.hpp b/sampling/include/var_opt_sketch.hpp index d2f7e9b..1357a3b 100644 --- a/sampling/include/var_opt_sketch.hpp +++ b/sampling/include/var_opt_sketch.hpp @@ -266,9 +266,18 @@ class var_opt_sketch { // occurs and is properly tracked. 
bool* marks_; + // used during deserialization to avoid memory leaks upon errors + class items_deleter; + class weights_deleter; + class marks_deleter; + var_opt_sketch(uint32_t k, resize_factor rf, bool is_gadget); - var_opt_sketch(uint32_t k, resize_factor rf, bool is_gadget, uint8_t preamble_longs, std::istream& is); - var_opt_sketch(uint32_t k, resize_factor rf, bool is_gadget, uint8_t preamble_longs, const void* bytes, size_t size); + var_opt_sketch(uint32_t k, uint32_t h, uint32_t m, uint32_t r, uint64_t n, double total_wt_r, resize_factor rf, + uint32_t curr_items_alloc, bool filled_data, std::unique_ptr<T, items_deleter> items, + std::unique_ptr<double, weights_deleter> weights, uint32_t num_marks_in_h, + std::unique_ptr<bool, marks_deleter> marks); + //var_opt_sketch(uint32_t k, resize_factor rf, bool is_gadget, uint8_t preamble_longs, std::istream& is); + //var_opt_sketch(uint32_t k, resize_factor rf, bool is_gadget, uint8_t preamble_longs, const void* bytes, size_t size); friend class var_opt_union<T,S,A>; var_opt_sketch(const var_opt_sketch& other, bool as_sketch, uint64_t adjusted_n); @@ -307,8 +316,8 @@ class var_opt_sketch { // validation static void check_preamble_longs(uint8_t preamble_longs, uint8_t flags); static void check_family_and_serialization_version(uint8_t family_id, uint8_t ser_ver); - // next method sets current_items_alloc_ - void validate_and_set_current_size(uint32_t preamble_longs); + static uint32_t validate_and_get_target_size(uint32_t preamble_longs, uint32_t k, uint64_t n, + uint32_t h, uint32_t r, resize_factor rf); // things to move to common and be shared among sketches static uint32_t get_adjusted_size(uint32_t max_size, uint32_t resize_target); diff --git a/sampling/include/var_opt_sketch_impl.hpp b/sampling/include/var_opt_sketch_impl.hpp index bf08b91..7b56510 100644 --- a/sampling/include/var_opt_sketch_impl.hpp +++ b/sampling/include/var_opt_sketch_impl.hpp @@ -175,6 +175,27 @@ 
var_opt_sketch<T,S,A>::var_opt_sketch(uint32_t k, resize_factor rf, bool is_gadg } template<typename T, typename S, typename A> +var_opt_sketch<T,S,A>::var_opt_sketch(uint32_t k, uint32_t h, uint32_t m, uint32_t r, uint64_t n, double total_wt_r, resize_factor rf, + uint32_t curr_items_alloc, bool filled_data, std::unique_ptr<T, items_deleter> items, + std::unique_ptr<double, weights_deleter> weights, uint32_t num_marks_in_h, + std::unique_ptr<bool, marks_deleter> marks) : + k_(k), + h_(h), + m_(m), + r_(r), + n_(n), + total_wt_r_(total_wt_r), + rf_(rf), + curr_items_alloc_(curr_items_alloc), + filled_data_(filled_data), + data_(items.release()), + weights_(weights.release()), + num_marks_in_h_(num_marks_in_h), + marks_(marks.release()) +{} + + +template<typename T, typename S, typename A> var_opt_sketch<T,S,A>::~var_opt_sketch() { if (data_ != nullptr) { if (filled_data_) { @@ -242,6 +263,7 @@ var_opt_sketch<T,S,A>& var_opt_sketch<T,S,A>::operator=(var_opt_sketch&& other) return *this; } +/* template<typename T, typename S, typename A> var_opt_sketch<T,S,A>::var_opt_sketch(uint32_t k, resize_factor rf, bool is_gadget, uint8_t preamble_longs, std::istream& is) : k_(k), m_(0), rf_(rf) { @@ -298,6 +320,7 @@ var_opt_sketch<T,S,A>::var_opt_sketch(uint32_t k, resize_factor rf, bool is_gadg S().deserialize(is, &data_[h_ + 1], r_); } + template<typename T, typename S, typename A> var_opt_sketch<T,S,A>::var_opt_sketch(uint32_t k, resize_factor rf, bool is_gadget, uint8_t preamble_longs, const void* bytes, size_t size) : k_(k), m_(0), rf_(rf) { @@ -360,6 +383,7 @@ var_opt_sketch<T,S,A>::var_opt_sketch(uint32_t k, resize_factor rf, bool is_gadg ptr += S().deserialize(ptr, end_ptr - ptr, data_, h_); ptr += S().deserialize(ptr, end_ptr - ptr, &data_[h_ + 1], r_); } +*/ /* * An empty sketch requires 8 bytes. 
@@ -569,6 +593,8 @@ template<typename T, typename S, typename A> var_opt_sketch<T,S,A> var_opt_sketch<T,S,A>::deserialize(const void* bytes, size_t size) { ensure_minimum_memory(size, 8); const char* ptr = static_cast<const char*>(bytes); + const char* base = ptr; + const char* end_ptr = ptr + size; uint8_t first_byte; ptr += copy_from_mem(ptr, &first_byte, sizeof(first_byte)); uint8_t preamble_longs = first_byte & 0x3f; @@ -589,7 +615,72 @@ var_opt_sketch<T,S,A> var_opt_sketch<T,S,A>::deserialize(const void* bytes, size const bool is_empty = flags & EMPTY_FLAG_MASK; const bool is_gadget = flags & GADGET_FLAG_MASK; - return is_empty ? var_opt_sketch<T,S,A>(k, rf, is_gadget) : var_opt_sketch<T,S,A>(k, rf, is_gadget, preamble_longs, bytes, size); + if (is_empty) { + return var_opt_sketch<T,S,A>(k, rf, is_gadget); + } + + // second and third prelongs + uint64_t n; + uint32_t h, r; + ptr += copy_from_mem(ptr, &n, sizeof(n)); + ptr += copy_from_mem(ptr, &h, sizeof(h)); + ptr += copy_from_mem(ptr, &r, sizeof(r)); + + const uint32_t array_size = validate_and_get_target_size(preamble_longs, k, n, h, r, rf); + + // current_items_alloc_ is set but validate R region weight (4th prelong), if needed, before allocating + double total_wt_r = 0.0; + if (preamble_longs == PREAMBLE_LONGS_FULL) { + ptr += copy_from_mem(ptr, &total_wt_r, sizeof(total_wt_r)); + if (std::isnan(total_wt_r) || r == 0 || total_wt_r <= 0.0) { + throw std::invalid_argument("Possible corruption: deserializing in full mode but r = 0 or invalid R weight. 
" + "Found r = " + std::to_string(r) + ", R region weight = " + std::to_string(total_wt_r)); + } + } else { + total_wt_r = 0.0; + } + + // read the first h_ weights, fill in rest of array with -1.0 + check_memory_size(ptr - base + (h * sizeof(double)), size); + std::unique_ptr<double, weights_deleter> weights(AllocDouble().allocate(array_size), weights_deleter(array_size)); + double* wts = weights.get(); // to avoid lots of .get() calls -- do not delete + ptr += copy_from_mem(ptr, wts, h * sizeof(double)); + for (size_t i = 0; i < h; ++i) { + if (!(wts[i] > 0.0)) { + throw std::invalid_argument("Possible corruption: Non-positive weight when deserializing: " + std::to_string(wts[i])); + } + } + std::fill(&wts[h], &wts[array_size], -1.0); + + // read the first h_ marks as packed bytes iff we have a gadget + uint32_t num_marks_in_h = 0; + std::unique_ptr<bool, marks_deleter> marks(nullptr, marks_deleter(array_size)); + if (is_gadget) { + uint8_t val = 0; + marks = std::unique_ptr<bool, marks_deleter>(AllocBool().allocate(array_size), marks_deleter(array_size)); + const size_t size_marks = (h / 8) + (h % 8 > 0 ? 1 : 0); + check_memory_size(ptr - base + size_marks, size); + for (uint32_t i = 0; i < h; ++i) { + if ((i & 0x7) == 0x0) { // should trigger on first iteration + ptr += copy_from_mem(ptr, &val, sizeof(val)); + } + marks.get()[i] = ((val >> (i & 0x7)) & 0x1) == 1; + num_marks_in_h += (marks.get()[i] ? 1 : 0); + } + } + + // read the sample items, skipping the gap. Either h_ or r_ may be 0 + items_deleter deleter(array_size); + std::unique_ptr<T, items_deleter> items(A().allocate(array_size), deleter); + + ptr += S().deserialize(ptr, end_ptr - ptr, items.get(), h); + deleter.set_h(h); // serde didn't throw, so the items are now valid + + ptr += S().deserialize(ptr, end_ptr - ptr, &(items.get()[h + 1]), r); + deleter.set_r(r); // serde didn't throw, so the items are now valid + + return var_opt_sketch(k, h, (r > 0 ? 
1 : 0), r, n, total_wt_r, rf, array_size, false, + std::move(items), std::move(weights), num_marks_in_h, std::move(marks)); } template<typename T, typename S, typename A> @@ -613,7 +704,69 @@ var_opt_sketch<T,S,A> var_opt_sketch<T,S,A>::deserialize(std::istream& is) { const bool is_empty = flags & EMPTY_FLAG_MASK; const bool is_gadget = flags & GADGET_FLAG_MASK; - return is_empty ? var_opt_sketch<T,S,A>(k, rf, is_gadget) : var_opt_sketch<T,S,A>(k, rf, is_gadget, preamble_longs, is); + if (is_empty) { + return var_opt_sketch<T,S,A>(k, rf, is_gadget); + } + + // second and third prelongs + uint64_t n; + uint32_t h, r; + is.read((char*)&n, sizeof(n)); + is.read((char*)&h, sizeof(h)); + is.read((char*)&r, sizeof(r)); + + const uint32_t array_size = validate_and_get_target_size(preamble_longs, k, n, h, r, rf); + + // current_items_alloc_ is set but validate R region weight (4th prelong), if needed, before allocating + double total_wt_r = 0.0; + if (preamble_longs == PREAMBLE_LONGS_FULL) { + is.read((char*)&total_wt_r, sizeof(total_wt_r)); + if (std::isnan(total_wt_r) || r == 0 || total_wt_r <= 0.0) { + throw std::invalid_argument("Possible corruption: deserializing in full mode but r = 0 or invalid R weight. 
" + "Found r = " + std::to_string(r) + ", R region weight = " + std::to_string(total_wt_r)); + } + } else { + total_wt_r = 0.0; + } + + // read the first h weights, fill remainder with -1.0 + std::unique_ptr<double, weights_deleter> weights(AllocDouble().allocate(array_size), weights_deleter(array_size)); + double* wts = weights.get(); // to avoid lots of .get() calls -- do not delete + is.read((char*)wts, h * sizeof(double)); + for (size_t i = 0; i < h; ++i) { + if (!(wts[i] > 0.0)) { + throw std::invalid_argument("Possible corruption: Non-positive weight when deserializing: " + std::to_string(wts[i])); + } + } + std::fill(&wts[h], &wts[array_size], -1.0); + + // read the first h_ marks as packed bytes iff we have a gadget + uint32_t num_marks_in_h = 0; + std::unique_ptr<bool, marks_deleter> marks(nullptr, marks_deleter(array_size)); + if (is_gadget) { + marks = std::unique_ptr<bool, marks_deleter>(AllocBool().allocate(array_size), marks_deleter(array_size)); + uint8_t val = 0; + for (uint32_t i = 0; i < h; ++i) { + if ((i & 0x7) == 0x0) { // should trigger on first iteration + is.read((char*)&val, sizeof(val)); + } + marks.get()[i] = ((val >> (i & 0x7)) & 0x1) == 1; + num_marks_in_h += (marks.get()[i] ? 1 : 0); + } + } + + // read the sample items, skipping the gap. Either h or r may be 0 + items_deleter deleter(array_size); + std::unique_ptr<T, items_deleter> items(A().allocate(array_size), deleter); + + S().deserialize(is, items.get(), h); // aka &data_[0] + deleter.set_h(h); // serde didn't throw, so the items are now valid + + S().deserialize(is, &(items.get()[h + 1]), r); + deleter.set_r(r); // serde didn't throw, so the items are now valid + + return var_opt_sketch(k, h, (r > 0 ? 
1 : 0), r, n, total_wt_r, rf, array_size, false, + std::move(items), std::move(weights), num_marks_in_h, std::move(marks)); } template<typename T, typename S, typename A> @@ -1299,46 +1452,50 @@ void var_opt_sketch<T,S,A>::check_family_and_serialization_version(uint8_t famil } template<typename T, typename S, typename A> -void var_opt_sketch<T, S, A>::validate_and_set_current_size(uint32_t preamble_longs) { - if (k_ == 0 || k_ > MAX_K) { +uint32_t var_opt_sketch<T, S, A>::validate_and_get_target_size(uint32_t preamble_longs, uint32_t k, uint64_t n, + uint32_t h, uint32_t r, resize_factor rf) { + if (k == 0 || k > MAX_K) { throw std::invalid_argument("k must be at least 1 and less than 2^31 - 1"); } - if (n_ <= k_) { + uint32_t array_size; + + if (n <= k) { if (preamble_longs != PREAMBLE_LONGS_WARMUP) { throw std::invalid_argument("Possible corruption: deserializing with n <= k but not in warmup mode. " - "Found n = " + std::to_string(n_) + ", k = " + std::to_string(k_)); + "Found n = " + std::to_string(n) + ", k = " + std::to_string(k)); } - if (n_ != h_) { + if (n != h) { throw std::invalid_argument("Possible corruption: deserializing in warmup mode but n != h. " - "Found n = " + std::to_string(n_) + ", h = " + std::to_string(h_)); + "Found n = " + std::to_string(n) + ", h = " + std::to_string(h)); } - if (r_ > 0) { + if (r > 0) { throw std::invalid_argument("Possible corruption: deserializing in warmup mode but r > 0. 
" - "Found r = " + std::to_string(r_)); + "Found r = " + std::to_string(r)); } - const uint32_t ceiling_lg_k = to_log_2(ceiling_power_of_2(k_)); - const uint32_t min_lg_size = to_log_2(ceiling_power_of_2(h_)); - const uint32_t initial_lg_size = starting_sub_multiple(ceiling_lg_k, rf_, min_lg_size); - curr_items_alloc_ = get_adjusted_size(k_, 1 << initial_lg_size); - if (curr_items_alloc_ == k_) { // if full size, need to leave 1 for the gap - ++curr_items_alloc_; + const uint32_t ceiling_lg_k = to_log_2(ceiling_power_of_2(k)); + const uint32_t min_lg_size = to_log_2(ceiling_power_of_2(h)); + const uint32_t initial_lg_size = starting_sub_multiple(ceiling_lg_k, rf, min_lg_size); + array_size = get_adjusted_size(k, 1 << initial_lg_size); + if (array_size == k) { // if full size, need to leave 1 for the gap + ++array_size; } - } else { // n_ > k_ + } else { // n > k if (preamble_longs != PREAMBLE_LONGS_FULL) { throw std::invalid_argument("Possible corruption: deserializing with n > k but not in full mode. " - "Found n = " + std::to_string(n_) + ", k = " + std::to_string(k_)); + "Found n = " + std::to_string(n) + ", k = " + std::to_string(k)); } - if (h_ + r_ != k_) { + if (h + r != k) { throw std::invalid_argument("Possible corruption: deserializing in full mode but h + r != n. 
" - "Found h = " + std::to_string(h_) + ", r = " + std::to_string(r_) + ", n = " + std::to_string(n_)); + "Found h = " + std::to_string(h) + ", r = " + std::to_string(r) + ", n = " + std::to_string(n)); } - curr_items_alloc_ = k_ + 1; + array_size = k + 1; } -} + return array_size; +} template<typename T, typename S, typename A> template<typename P> @@ -1389,6 +1546,61 @@ subset_summary var_opt_sketch<T, S, A>::estimate_subset_sum(P predicate) const { } template<typename T, typename S, typename A> +class var_opt_sketch<T, S, A>::items_deleter { + public: + items_deleter(uint32_t num) : num(num), h_count(0), r_count(0) {} + void set_h(uint32_t h) { h_count = h; } + void set_r(uint32_t r) { r_count = r; } + void operator() (T* ptr) const { + if (h_count > 0) { + for (size_t i = 0; i < h_count; ++i) { + ptr[i].~T(); + } + } + if (r_count > 0) { + uint32_t end = h_count + r_count + 1; + for (size_t i = h_count + 1; i < end; ++i) { + ptr[i].~T(); + } + } + if (ptr != nullptr) { + A().deallocate(ptr, num); + } + } + private: + uint32_t num; + uint32_t h_count; + uint32_t r_count; +}; + +template<typename T, typename S, typename A> +class var_opt_sketch<T, S, A>::weights_deleter { + public: + weights_deleter(uint32_t num) : num(num) {} + void operator() (double* ptr) const { + if (ptr != nullptr) { + AllocDouble().deallocate(ptr, num); + } + } + private: + uint32_t num; +}; + +template<typename T, typename S, typename A> +class var_opt_sketch<T, S, A>::marks_deleter { + public: + marks_deleter(uint32_t num) : num(num) {} + void operator() (bool* ptr) const { + if (ptr != nullptr) { + AllocBool().deallocate(ptr, 1); + } + } + private: + uint32_t num; +}; + + +template<typename T, typename S, typename A> typename var_opt_sketch<T, S, A>::const_iterator var_opt_sketch<T, S, A>::begin() const { return var_opt_sketch<T, S, A>::const_iterator(*this, false); } --------------------------------------------------------------------- To unsubscribe, e-mail: [email protected] For 
additional commands, e-mail: [email protected]
