This is an automated email from the ASF dual-hosted git repository.

jmalkin pushed a commit to branch vo_exception_safety
in repository https://gitbox.apache.org/repos/asf/incubator-datasketches-cpp.git

commit b4314e9c0c9da5c42d1e80a3da324db89adf9d02
Author: Jon Malkin <[email protected]>
AuthorDate: Fri May 8 13:33:17 2020 -0700

    add missing bounds check to varopt union, add exception-handling tests
---
 sampling/include/var_opt_sketch_impl.hpp | 122 -------------------------------
 sampling/include/var_opt_union_impl.hpp  |   1 +
 sampling/test/var_opt_sketch_test.cpp    |  27 +++++++
 sampling/test/var_opt_union_test.cpp     |  15 +++-
 4 files changed, 42 insertions(+), 123 deletions(-)

diff --git a/sampling/include/var_opt_sketch_impl.hpp b/sampling/include/var_opt_sketch_impl.hpp
index 7b56510..b24f1f5 100644
--- a/sampling/include/var_opt_sketch_impl.hpp
+++ b/sampling/include/var_opt_sketch_impl.hpp
@@ -264,128 +264,6 @@ var_opt_sketch<T,S,A>& var_opt_sketch<T,S,A>::operator=(var_opt_sketch&& other)
 }
 
 /*
-template<typename T, typename S, typename A>
-var_opt_sketch<T,S,A>::var_opt_sketch(uint32_t k, resize_factor rf, bool 
is_gadget, uint8_t preamble_longs, std::istream& is) :
-  k_(k), m_(0), rf_(rf) {
-
-  // second and third prelongs
-  is.read((char*)&n_, sizeof(uint64_t));
-  is.read((char*)&h_, sizeof(uint32_t));
-  is.read((char*)&r_, sizeof(uint32_t));
-
-  validate_and_set_current_size(preamble_longs);
-
-  // current_items_alloc_ is set but validate R region weight (4th prelong), 
if needed, before allocating
-  if (preamble_longs == PREAMBLE_LONGS_FULL) { 
-    is.read((char*)&total_wt_r_, sizeof(total_wt_r_));
-    if (std::isnan(total_wt_r_) || r_ == 0 || total_wt_r_ <= 0.0) {
-      throw std::invalid_argument("Possible corruption: deserializing in full 
mode but r = 0 or invalid R weight. "
-       "Found r = " + std::to_string(r_) + ", R region weight = " + 
std::to_string(total_wt_r_));
-    }
-  } else {
-    total_wt_r_ = 0.0;
-  }
-
-  allocate_data_arrays(curr_items_alloc_, is_gadget);
-
-  // read the first h_ weights
-  is.read((char*)weights_, h_ * sizeof(double));
-  for (size_t i = 0; i < h_; ++i) {
-    if (!(weights_[i] > 0.0)) {
-      const std::string msg("Possible corruption: Non-positive weight when 
deserializing: " + std::to_string(weights_[i]));
-      A().deallocate(data_, curr_items_alloc_);
-      AllocDouble().deallocate(weights_, curr_items_alloc_);
-      if (marks_ != nullptr) { AllocBool().deallocate(marks_, 
curr_items_alloc_); }
-      throw std::invalid_argument(msg);
-    }
-  }
-
-  std::fill(&weights_[h_], &weights_[curr_items_alloc_], -1.0);
-
-  // read the first h_ marks as packed bytes iff we have a gadget
-  num_marks_in_h_ = 0;
-  if (is_gadget) {
-    uint8_t val = 0;
-    for (uint32_t i = 0; i < h_; ++i) {
-      if ((i & 0x7) == 0x0) { // should trigger on first iteration
-        is.read((char*)&val, sizeof(val));
-      }
-      marks_[i] = ((val >> (i & 0x7)) & 0x1) == 1;
-      num_marks_in_h_ += (marks_[i] ? 1 : 0);
-    }
-  }
-
-  // read the sample items, skipping the gap. Either h_ or r_ may be 0
-  S().deserialize(is, data_, h_); // aka &data_[0]
-  S().deserialize(is, &data_[h_ + 1], r_);
-}
-
-
-template<typename T, typename S, typename A>
-var_opt_sketch<T,S,A>::var_opt_sketch(uint32_t k, resize_factor rf, bool 
is_gadget, uint8_t preamble_longs,
-                                      const void* bytes, size_t size) : k_(k), 
m_(0), rf_(rf) {
-  // private constructor so we assume not called if sketch is empty,
-  // and that the array is large enough to hold the preamble
-  const char* base = static_cast<const char*>(bytes);
-  const char* ptr = static_cast<const char*>(bytes) + sizeof(uint64_t);
-  const char* end_ptr = static_cast<const char*>(bytes) + size;
-
-  // second and third prelongs
-  ptr += copy_from_mem(ptr, &n_, sizeof(n_));
-  ptr += copy_from_mem(ptr, &h_, sizeof(h_));
-  ptr += copy_from_mem(ptr, &r_, sizeof(r_));
-
-  validate_and_set_current_size(preamble_longs);
-  
-  // current_items_alloc_ is set but validate R region weight (4th prelong), 
if needed, before allocating
-  if (preamble_longs == PREAMBLE_LONGS_FULL) {
-    ptr += copy_from_mem(ptr, &total_wt_r_, sizeof(total_wt_r_));
-    if (std::isnan(total_wt_r_) || r_ == 0 || total_wt_r_ <= 0.0) {
-      throw std::invalid_argument("Possible corruption: deserializing in full 
mode but r = 0 or invalid R weight. "
-       "Found r = " + std::to_string(r_) + ", R region weight = " + 
std::to_string(total_wt_r_));
-    }
-  } else {
-    total_wt_r_ = 0.0;
-  }
-
-  allocate_data_arrays(curr_items_alloc_, is_gadget);
-
-  // read the first h_ weights, fill in rest of array with -1.0
-  check_memory_size(ptr - base + (h_ * sizeof(double)), size);
-  ptr += copy_from_mem(ptr, weights_, h_ * sizeof(double));
-  for (size_t i = 0; i < h_; ++i) {
-    if (!(weights_[i] > 0.0)) {
-      const std::string msg("Possible corruption: Non-positive weight when 
deserializing: " + std::to_string(weights_[i]));
-      A().deallocate(data_, curr_items_alloc_);
-      AllocDouble().deallocate(weights_, curr_items_alloc_);
-      if (marks_ != nullptr) { AllocBool().deallocate(marks_, 
curr_items_alloc_); }
-      throw std::invalid_argument(msg);
-    }
-  }
-  std::fill(&weights_[h_], &weights_[curr_items_alloc_], -1.0);
-  
-  // read the first h_ marks as packed bytes iff we have a gadget
-  num_marks_in_h_ = 0;
-  if (is_gadget) {
-    uint8_t val = 0;
-    const size_t size_marks = (h_ / 8) + (h_ % 8 > 0 ? 1 : 0);
-    check_memory_size(ptr - base + size_marks, size);
-    for (uint32_t i = 0; i < h_; ++i) {
-     if ((i & 0x7) == 0x0) { // should trigger on first iteration
-        ptr += copy_from_mem(ptr, &val, sizeof(val));
-      }
-      marks_[i] = ((val >> (i & 0x7)) & 0x1) == 1;
-      num_marks_in_h_ += (marks_[i] ? 1 : 0);
-    }
-  }
-
-  // read the sample items, skipping the gap. Either h_ or r_ may be 0
-  ptr += S().deserialize(ptr, end_ptr - ptr, data_, h_);
-  ptr += S().deserialize(ptr, end_ptr - ptr, &data_[h_ + 1], r_);
-}
-*/
-
-/*
  * An empty sketch requires 8 bytes.
  *
  * <pre>
diff --git a/sampling/include/var_opt_union_impl.hpp b/sampling/include/var_opt_union_impl.hpp
index f2098dd..5897557 100644
--- a/sampling/include/var_opt_union_impl.hpp
+++ b/sampling/include/var_opt_union_impl.hpp
@@ -167,6 +167,7 @@ var_opt_union<T,S,A> var_opt_union<T,S,A>::deserialize(std::istream& is) {
 
 template<typename T, typename S, typename A>
 var_opt_union<T,S,A> var_opt_union<T,S,A>::deserialize(const void* bytes, 
size_t size) {
+  ensure_minimum_memory(size, 8);
   const char* ptr = static_cast<const char*>(bytes);
   uint8_t preamble_longs;
   ptr += copy_from_mem(ptr, &preamble_longs, sizeof(preamble_longs));
diff --git a/sampling/test/var_opt_sketch_test.cpp b/sampling/test/var_opt_sketch_test.cpp
index d912a2a..8c5e15a 100644
--- a/sampling/test/var_opt_sketch_test.cpp
+++ b/sampling/test/var_opt_sketch_test.cpp
@@ -69,6 +69,13 @@ static void check_if_equal(var_opt_sketch<T,S,A>& sk1, var_opt_sketch<T,S,A>& sk
   REQUIRE((it1 == sk1.end() && it2 == sk2.end())); // iterators must end at 
the same time
 }
 
+// Builds a binary stringstream from the first `length` bytes of `bytes` (clamped to bytes.size()); used to feed truncated images to deserialize().
+static std::stringstream create_stringstream_with_length(const std::vector<uint8_t>& bytes, size_t length) {
+  std::stringstream ss(std::ios::in | std::ios::out | std::ios::binary);
+  ss.str(std::string(bytes.begin(), bytes.begin() + (length < bytes.size() ? length : bytes.size())));
+  return ss;
+}
+
 TEST_CASE("varopt sketch: invalid k", "[var_opt_sketch]") {
   REQUIRE_THROWS_AS(var_opt_sketch<int>(0), std::invalid_argument);
   REQUIRE_THROWS_AS(var_opt_sketch<int>(1 << 31), std::invalid_argument); // 
aka k < 0
@@ -239,6 +246,11 @@ TEST_CASE("varopt sketch: under-full sketch serialization", "[var_opt_sketch]")
   sk.serialize(ss);
   var_opt_sketch<int> sk_from_stream = var_opt_sketch<int>::deserialize(ss);
   check_if_equal(sk, sk_from_stream);
+
+  // ensure we unroll properly
+  REQUIRE_THROWS_AS(var_opt_sketch<int>::deserialize(bytes.data(), 
bytes.size() - 1), std::out_of_range);
+  std::stringstream ss_trunc = create_stringstream_with_length(bytes, 
bytes.size() - 1);
+  REQUIRE_THROWS_AS(var_opt_sketch<int>::deserialize(ss_trunc), 
std::runtime_error);
 }
 
 TEST_CASE("varopt sketch: end-of-warmup sketch serialization", 
"[var_opt_sketch]") {
@@ -255,6 +267,11 @@ TEST_CASE("varopt sketch: end-of-warmup sketch serialization", "[var_opt_sketch]
   sk.serialize(ss);
   var_opt_sketch<int> sk_from_stream = var_opt_sketch<int>::deserialize(ss);
   check_if_equal(sk, sk_from_stream);
+
+  // ensure we unroll properly
+  REQUIRE_THROWS_AS(var_opt_sketch<int>::deserialize(bytes.data(), 
bytes.size() - 1000), std::out_of_range);
+  std::stringstream ss_trunc = create_stringstream_with_length(bytes, 
bytes.size() - 1000);
+  REQUIRE_THROWS_AS(var_opt_sketch<int>::deserialize(ss_trunc), 
std::runtime_error);
 }
 
 TEST_CASE("varopt sketch: full sketch serialization", "[var_opt_sketch]") {
@@ -283,6 +300,11 @@ TEST_CASE("varopt sketch: full sketch serialization", "[var_opt_sketch]") {
   sk.serialize(ss);
   var_opt_sketch<int> sk_from_stream = var_opt_sketch<int>::deserialize(ss);
   check_if_equal(sk, sk_from_stream);
+
+  // ensure we unroll properly
+  REQUIRE_THROWS_AS(var_opt_sketch<int>::deserialize(bytes.data(), 
bytes.size() - 100), std::out_of_range);
+  std::stringstream ss_trunc = create_stringstream_with_length(bytes, 
bytes.size() - 100);
+  REQUIRE_THROWS_AS(var_opt_sketch<int>::deserialize(ss_trunc), 
std::runtime_error);
 }
 
 TEST_CASE("varopt sketch: string serialization", "[var_opt_sketch]") {
@@ -302,6 +324,11 @@ TEST_CASE("varopt sketch: string serialization", "[var_opt_sketch]") {
   sk.serialize(ss);
   var_opt_sketch<std::string> sk_from_stream = 
var_opt_sketch<std::string>::deserialize(ss);
   check_if_equal(sk, sk_from_stream);
+
+  // ensure we unroll properly
+  REQUIRE_THROWS_AS(var_opt_sketch<std::string>::deserialize(bytes.data(), 
bytes.size() - 12), std::out_of_range);
+  std::stringstream ss_trunc = create_stringstream_with_length(bytes, 
bytes.size() - 12);
+  REQUIRE_THROWS_AS(var_opt_sketch<std::string>::deserialize(ss_trunc), 
std::runtime_error);
 }
 
 TEST_CASE("varopt sketch: pseudo-light update", "[var_opt_sketch]") {
diff --git a/sampling/test/var_opt_union_test.cpp b/sampling/test/var_opt_union_test.cpp
index 62440d4..e2321c5 100644
--- a/sampling/test/var_opt_union_test.cpp
+++ b/sampling/test/var_opt_union_test.cpp
@@ -54,6 +54,13 @@ static void check_if_equal(var_opt_sketch<T,S,A>& sk1, var_opt_sketch<T,S,A>& sk
   REQUIRE((it1 == sk1.end() && it2 == sk2.end())); // iterators must end at 
the same time
 }
 
+// Builds a binary stringstream from the first `length` bytes of `bytes` (clamped to bytes.size()); used to feed truncated images to deserialize().
+static std::stringstream create_stringstream_with_length(const std::vector<uint8_t>& bytes, size_t length) {
+  std::stringstream ss(std::ios::in | std::ios::out | std::ios::binary);
+  ss.str(std::string(bytes.begin(), bytes.begin() + (length < bytes.size() ? length : bytes.size())));
+  return ss;
+}
+
 // compare serialization and deserialization results, checking string and 
stream methods to
 // ensure that the resulting binary images are compatible.
 // if exact_compare = false, checks for equivalence -- specific R region 
values may differ but
@@ -85,9 +92,15 @@ static void compare_serialization_deserialization(var_opt_union<T,S,A>& vo_union
   var_opt_union<T> u_from_str = 
var_opt_union<T>::deserialize(str_from_stream.c_str(), str_from_stream.size());
   sk2 = u_from_str.get_result();
   check_if_equal(sk1, sk2, exact_compare);
+
+  // check truncated input, too
+  REQUIRE_THROWS_AS(var_opt_union<T>::deserialize(bytes.data(), bytes.size() - 
5), std::out_of_range);
+  std::stringstream ss_trunc = create_stringstream_with_length(bytes, 
bytes.size() - 5);
+  // next line may throw either std::invalid_argument or std::runtime_error
+  REQUIRE_THROWS_AS(var_opt_union<T>::deserialize(ss_trunc), std::exception);
 }
 
-TEST_CASE("varopt union: bad predlongs", "[var_opt_union]") {
+TEST_CASE("varopt union: bad prelongs", "[var_opt_union]") {
   var_opt_sketch<int> sk = create_unweighted_sketch(32, 33);
   var_opt_union<int> u(32);
   u.update(sk);


---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]

Reply via email to