This is an automated email from the ASF dual-hosted git repository. jmalkin pushed a commit to branch vo_exception_safety in repository https://gitbox.apache.org/repos/asf/incubator-datasketches-cpp.git
commit b4314e9c0c9da5c42d1e80a3da324db89adf9d02 Author: Jon Malkin <[email protected]> AuthorDate: Fri May 8 13:33:17 2020 -0700 add missing bounds check to varopt union, add exception-handling tests --- sampling/include/var_opt_sketch_impl.hpp | 122 ------------------------------- sampling/include/var_opt_union_impl.hpp | 1 + sampling/test/var_opt_sketch_test.cpp | 27 +++++++ sampling/test/var_opt_union_test.cpp | 15 +++- 4 files changed, 42 insertions(+), 123 deletions(-) diff --git a/sampling/include/var_opt_sketch_impl.hpp b/sampling/include/var_opt_sketch_impl.hpp index 7b56510..b24f1f5 100644 --- a/sampling/include/var_opt_sketch_impl.hpp +++ b/sampling/include/var_opt_sketch_impl.hpp @@ -264,128 +264,6 @@ var_opt_sketch<T,S,A>& var_opt_sketch<T,S,A>::operator=(var_opt_sketch&& other) } /* -template<typename T, typename S, typename A> -var_opt_sketch<T,S,A>::var_opt_sketch(uint32_t k, resize_factor rf, bool is_gadget, uint8_t preamble_longs, std::istream& is) : - k_(k), m_(0), rf_(rf) { - - // second and third prelongs - is.read((char*)&n_, sizeof(uint64_t)); - is.read((char*)&h_, sizeof(uint32_t)); - is.read((char*)&r_, sizeof(uint32_t)); - - validate_and_set_current_size(preamble_longs); - - // current_items_alloc_ is set but validate R region weight (4th prelong), if needed, before allocating - if (preamble_longs == PREAMBLE_LONGS_FULL) { - is.read((char*)&total_wt_r_, sizeof(total_wt_r_)); - if (std::isnan(total_wt_r_) || r_ == 0 || total_wt_r_ <= 0.0) { - throw std::invalid_argument("Possible corruption: deserializing in full mode but r = 0 or invalid R weight. 
" - "Found r = " + std::to_string(r_) + ", R region weight = " + std::to_string(total_wt_r_)); - } - } else { - total_wt_r_ = 0.0; - } - - allocate_data_arrays(curr_items_alloc_, is_gadget); - - // read the first h_ weights - is.read((char*)weights_, h_ * sizeof(double)); - for (size_t i = 0; i < h_; ++i) { - if (!(weights_[i] > 0.0)) { - const std::string msg("Possible corruption: Non-positive weight when deserializing: " + std::to_string(weights_[i])); - A().deallocate(data_, curr_items_alloc_); - AllocDouble().deallocate(weights_, curr_items_alloc_); - if (marks_ != nullptr) { AllocBool().deallocate(marks_, curr_items_alloc_); } - throw std::invalid_argument(msg); - } - } - - std::fill(&weights_[h_], &weights_[curr_items_alloc_], -1.0); - - // read the first h_ marks as packed bytes iff we have a gadget - num_marks_in_h_ = 0; - if (is_gadget) { - uint8_t val = 0; - for (uint32_t i = 0; i < h_; ++i) { - if ((i & 0x7) == 0x0) { // should trigger on first iteration - is.read((char*)&val, sizeof(val)); - } - marks_[i] = ((val >> (i & 0x7)) & 0x1) == 1; - num_marks_in_h_ += (marks_[i] ? 1 : 0); - } - } - - // read the sample items, skipping the gap. 
Either h_ or r_ may be 0 - S().deserialize(is, data_, h_); // aka &data_[0] - S().deserialize(is, &data_[h_ + 1], r_); -} - - -template<typename T, typename S, typename A> -var_opt_sketch<T,S,A>::var_opt_sketch(uint32_t k, resize_factor rf, bool is_gadget, uint8_t preamble_longs, - const void* bytes, size_t size) : k_(k), m_(0), rf_(rf) { - // private constructor so we assume not called if sketch is empty, - // and that the array is large enough to hold the preamble - const char* base = static_cast<const char*>(bytes); - const char* ptr = static_cast<const char*>(bytes) + sizeof(uint64_t); - const char* end_ptr = static_cast<const char*>(bytes) + size; - - // second and third prelongs - ptr += copy_from_mem(ptr, &n_, sizeof(n_)); - ptr += copy_from_mem(ptr, &h_, sizeof(h_)); - ptr += copy_from_mem(ptr, &r_, sizeof(r_)); - - validate_and_set_current_size(preamble_longs); - - // current_items_alloc_ is set but validate R region weight (4th prelong), if needed, before allocating - if (preamble_longs == PREAMBLE_LONGS_FULL) { - ptr += copy_from_mem(ptr, &total_wt_r_, sizeof(total_wt_r_)); - if (std::isnan(total_wt_r_) || r_ == 0 || total_wt_r_ <= 0.0) { - throw std::invalid_argument("Possible corruption: deserializing in full mode but r = 0 or invalid R weight. 
" - "Found r = " + std::to_string(r_) + ", R region weight = " + std::to_string(total_wt_r_)); - } - } else { - total_wt_r_ = 0.0; - } - - allocate_data_arrays(curr_items_alloc_, is_gadget); - - // read the first h_ weights, fill in rest of array with -1.0 - check_memory_size(ptr - base + (h_ * sizeof(double)), size); - ptr += copy_from_mem(ptr, weights_, h_ * sizeof(double)); - for (size_t i = 0; i < h_; ++i) { - if (!(weights_[i] > 0.0)) { - const std::string msg("Possible corruption: Non-positive weight when deserializing: " + std::to_string(weights_[i])); - A().deallocate(data_, curr_items_alloc_); - AllocDouble().deallocate(weights_, curr_items_alloc_); - if (marks_ != nullptr) { AllocBool().deallocate(marks_, curr_items_alloc_); } - throw std::invalid_argument(msg); - } - } - std::fill(&weights_[h_], &weights_[curr_items_alloc_], -1.0); - - // read the first h_ marks as packed bytes iff we have a gadget - num_marks_in_h_ = 0; - if (is_gadget) { - uint8_t val = 0; - const size_t size_marks = (h_ / 8) + (h_ % 8 > 0 ? 1 : 0); - check_memory_size(ptr - base + size_marks, size); - for (uint32_t i = 0; i < h_; ++i) { - if ((i & 0x7) == 0x0) { // should trigger on first iteration - ptr += copy_from_mem(ptr, &val, sizeof(val)); - } - marks_[i] = ((val >> (i & 0x7)) & 0x1) == 1; - num_marks_in_h_ += (marks_[i] ? 1 : 0); - } - } - - // read the sample items, skipping the gap. Either h_ or r_ may be 0 - ptr += S().deserialize(ptr, end_ptr - ptr, data_, h_); - ptr += S().deserialize(ptr, end_ptr - ptr, &data_[h_ + 1], r_); -} -*/ - -/* * An empty sketch requires 8 bytes. 
* * <pre> diff --git a/sampling/include/var_opt_union_impl.hpp b/sampling/include/var_opt_union_impl.hpp index f2098dd..5897557 100644 --- a/sampling/include/var_opt_union_impl.hpp +++ b/sampling/include/var_opt_union_impl.hpp @@ -167,6 +167,7 @@ var_opt_union<T,S,A> var_opt_union<T,S,A>::deserialize(std::istream& is) { template<typename T, typename S, typename A> var_opt_union<T,S,A> var_opt_union<T,S,A>::deserialize(const void* bytes, size_t size) { + ensure_minimum_memory(size, 8); const char* ptr = static_cast<const char*>(bytes); uint8_t preamble_longs; ptr += copy_from_mem(ptr, &preamble_longs, sizeof(preamble_longs)); diff --git a/sampling/test/var_opt_sketch_test.cpp b/sampling/test/var_opt_sketch_test.cpp index d912a2a..8c5e15a 100644 --- a/sampling/test/var_opt_sketch_test.cpp +++ b/sampling/test/var_opt_sketch_test.cpp @@ -69,6 +69,13 @@ static void check_if_equal(var_opt_sketch<T,S,A>& sk1, var_opt_sketch<T,S,A>& sk REQUIRE((it1 == sk1.end() && it2 == sk2.end())); // iterators must end at the same time } +static std::stringstream create_stringstream_with_length(std::vector<uint8_t> bytes, size_t length) { + std::stringstream ss(std::ios::in | std::ios::out | std::ios::binary); + std::string str((char*)&bytes[0], length); + ss.str(str); + return ss; +} + TEST_CASE("varopt sketch: invalid k", "[var_opt_sketch]") { REQUIRE_THROWS_AS(var_opt_sketch<int>(0), std::invalid_argument); REQUIRE_THROWS_AS(var_opt_sketch<int>(1 << 31), std::invalid_argument); // aka k < 0 @@ -239,6 +246,11 @@ TEST_CASE("varopt sketch: under-full sketch serialization", "[var_opt_sketch]") sk.serialize(ss); var_opt_sketch<int> sk_from_stream = var_opt_sketch<int>::deserialize(ss); check_if_equal(sk, sk_from_stream); + + // ensure we unroll properly + REQUIRE_THROWS_AS(var_opt_sketch<int>::deserialize(bytes.data(), bytes.size() - 1), std::out_of_range); + std::stringstream ss_trunc = create_stringstream_with_length(bytes, bytes.size() - 1); + 
REQUIRE_THROWS_AS(var_opt_sketch<int>::deserialize(ss_trunc), std::runtime_error); } TEST_CASE("varopt sketch: end-of-warmup sketch serialization", "[var_opt_sketch]") { @@ -255,6 +267,11 @@ TEST_CASE("varopt sketch: end-of-warmup sketch serialization", "[var_opt_sketch] sk.serialize(ss); var_opt_sketch<int> sk_from_stream = var_opt_sketch<int>::deserialize(ss); check_if_equal(sk, sk_from_stream); + + // ensure we unroll properly + REQUIRE_THROWS_AS(var_opt_sketch<int>::deserialize(bytes.data(), bytes.size() - 1000), std::out_of_range); + std::stringstream ss_trunc = create_stringstream_with_length(bytes, bytes.size() - 1000); + REQUIRE_THROWS_AS(var_opt_sketch<int>::deserialize(ss_trunc), std::runtime_error); } TEST_CASE("varopt sketch: full sketch serialization", "[var_opt_sketch]") { @@ -283,6 +300,11 @@ TEST_CASE("varopt sketch: full sketch serialization", "[var_opt_sketch]") { sk.serialize(ss); var_opt_sketch<int> sk_from_stream = var_opt_sketch<int>::deserialize(ss); check_if_equal(sk, sk_from_stream); + + // ensure we unroll properly + REQUIRE_THROWS_AS(var_opt_sketch<int>::deserialize(bytes.data(), bytes.size() - 100), std::out_of_range); + std::stringstream ss_trunc = create_stringstream_with_length(bytes, bytes.size() - 100); + REQUIRE_THROWS_AS(var_opt_sketch<int>::deserialize(ss_trunc), std::runtime_error); } TEST_CASE("varopt sketch: string serialization", "[var_opt_sketch]") { @@ -302,6 +324,11 @@ TEST_CASE("varopt sketch: string serialization", "[var_opt_sketch]") { sk.serialize(ss); var_opt_sketch<std::string> sk_from_stream = var_opt_sketch<std::string>::deserialize(ss); check_if_equal(sk, sk_from_stream); + + // ensure we unroll properly + REQUIRE_THROWS_AS(var_opt_sketch<std::string>::deserialize(bytes.data(), bytes.size() - 12), std::out_of_range); + std::stringstream ss_trunc = create_stringstream_with_length(bytes, bytes.size() - 12); + REQUIRE_THROWS_AS(var_opt_sketch<std::string>::deserialize(ss_trunc), std::runtime_error); } 
TEST_CASE("varopt sketch: pseudo-light update", "[var_opt_sketch]") { diff --git a/sampling/test/var_opt_union_test.cpp b/sampling/test/var_opt_union_test.cpp index 62440d4..e2321c5 100644 --- a/sampling/test/var_opt_union_test.cpp +++ b/sampling/test/var_opt_union_test.cpp @@ -54,6 +54,13 @@ static void check_if_equal(var_opt_sketch<T,S,A>& sk1, var_opt_sketch<T,S,A>& sk REQUIRE((it1 == sk1.end() && it2 == sk2.end())); // iterators must end at the same time } +static std::stringstream create_stringstream_with_length(std::vector<uint8_t> bytes, size_t length) { + std::stringstream ss(std::ios::in | std::ios::out | std::ios::binary); + std::string str((char*)&bytes[0], length); + ss.str(str); + return ss; +} + // compare serialization and deserialization results, checking string and stream methods to // ensure that the resulting binary images are compatible. // if exact_compare = false, checks for equivalence -- specific R region values may differ but @@ -85,9 +92,15 @@ static void compare_serialization_deserialization(var_opt_union<T,S,A>& vo_union var_opt_union<T> u_from_str = var_opt_union<T>::deserialize(str_from_stream.c_str(), str_from_stream.size()); sk2 = u_from_str.get_result(); check_if_equal(sk1, sk2, exact_compare); + + // check truncated input, too + REQUIRE_THROWS_AS(var_opt_union<T>::deserialize(bytes.data(), bytes.size() - 5), std::out_of_range); + std::stringstream ss_trunc = create_stringstream_with_length(bytes, bytes.size() - 5); + // next line may throw either std::invalid_argument or std::runtime_error + REQUIRE_THROWS_AS(var_opt_union<T>::deserialize(ss_trunc), std::exception); } -TEST_CASE("varopt union: bad predlongs", "[var_opt_union]") { +TEST_CASE("varopt union: bad prelongs", "[var_opt_union]") { var_opt_sketch<int> sk = create_unweighted_sketch(32, 33); var_opt_union<int> u(32); u.update(sk); --------------------------------------------------------------------- To unsubscribe, e-mail: [email protected] For additional commands, 
e-mail: [email protected]
