This is an automated email from the ASF dual-hosted git repository.

alsay pushed a commit to branch tdigest
in repository https://gitbox.apache.org/repos/asf/datasketches-cpp.git

commit dc032a760a816389de52c8edd274f1a4f83bb36f
Author: AlexanderSaydakov <[email protected]>
AuthorDate: Mon Feb 5 17:46:20 2024 -0800

    deserialize format of the reference implementation
---
 common/include/common_defs.hpp   |  17 +++++++
 tdigest/include/tdigest.hpp      |   7 +++
 tdigest/include/tdigest_impl.hpp | 107 +++++++++++++++++++++++++++++++++++++++
 tdigest/test/tdigest_test.cpp    |  67 ++++++++++++++++++++++++
 4 files changed, 198 insertions(+)

diff --git a/common/include/common_defs.hpp b/common/include/common_defs.hpp
index d8e3e6c..8a61ff3 100644
--- a/common/include/common_defs.hpp
+++ b/common/include/common_defs.hpp
@@ -91,6 +91,23 @@ static inline void write(std::ostream& os, const T* ptr, size_t size_bytes) {
   os.write(reinterpret_cast<const char*>(ptr), size_bytes);
 }
 
+// Returns `value` with its byte order reversed, e.g. to turn a number read in
+// big-endian order into little-endian order (or vice versa). It permutes the raw
+// object representation, so T should be a trivially copyable type such as an
+// integer or floating-point number.
+template<typename T>
+static inline T byteswap(T value) {
+  char* ptr = static_cast<char*>(static_cast<void*>(&value));
+  const size_t len = sizeof(T); // size_t, like the loop index, to avoid a signed/unsigned comparison
+  for (size_t i = 0; i < len / 2; ++i) {
+    std::swap(ptr[i], ptr[len - i - 1]);
+  }
+  return value;
+}
+
+// Reads sizeof(T) bytes from the stream and interprets them as a big-endian T.
+// NOTE: the bytes are swapped unconditionally, so this assumes a little-endian
+// host. If the read fails the stream's state flags record it and the returned
+// value is unspecified, so callers should check the stream state.
+template<typename T>
+static inline T read_big_endian(std::istream& is) {
+  T value{}; // value-initialized so a failed or short read never swaps indeterminate bytes
+  is.read(reinterpret_cast<char*>(&value), sizeof(T));
+  return byteswap(value);
+}
+
 // wrapper for iterators to implement operator-> returning temporary value
 template<typename T>
 class return_value_holder {
diff --git a/tdigest/include/tdigest.hpp b/tdigest/include/tdigest.hpp
index 357f203..d60a396 100644
--- a/tdigest/include/tdigest.hpp
+++ b/tdigest/include/tdigest.hpp
@@ -228,6 +228,9 @@ private:
   static const uint8_t SERIAL_VERSION = 1;
   static const uint8_t SKETCH_TYPE = 20;
 
+  static const uint8_t COMPAT_DOUBLE = 1;
+  static const uint8_t COMPAT_FLOAT = 2;
+
   enum flags { IS_EMPTY, REVERSE_MERGE };
 
   // for deserialize
@@ -238,6 +241,10 @@ private:
   void merge_new_values(uint16_t k);
 
   static double weighted_average(double x1, double w1, double x2, double w2);
+
+  // for compatibility with format of the reference implementation
+  static tdigest deserialize_compat(std::istream& is, const Allocator& 
allocator = Allocator());
+  static tdigest deserialize_compat(const void* bytes, size_t size, const 
Allocator& allocator = Allocator());
 };
 
 } /* namespace datasketches */
diff --git a/tdigest/include/tdigest_impl.hpp b/tdigest/include/tdigest_impl.hpp
index c47202e..fa301a1 100644
--- a/tdigest/include/tdigest_impl.hpp
+++ b/tdigest/include/tdigest_impl.hpp
@@ -352,6 +352,7 @@ tdigest<T, A> tdigest<T, A>::deserialize(std::istream& is, const A& allocator) {
   const auto serial_version = read<uint8_t>(is);
   const auto sketch_type = read<uint8_t>(is);
   if (sketch_type != SKETCH_TYPE) {
+    if (preamble_longs == 0 && serial_version == 0 && sketch_type == 0) return 
deserialize_compat(is, allocator);
     throw std::invalid_argument("sketch type mismatch: expected " + 
std::to_string(SKETCH_TYPE) + ", actual " + std::to_string(sketch_type));
   }
   if (serial_version != SERIAL_VERSION) {
@@ -391,6 +392,7 @@ tdigest<T, A> tdigest<T, A>::deserialize(const void* bytes, size_t size, const A
   const uint8_t serial_version = *ptr++;
   const uint8_t sketch_type = *ptr++;
   if (sketch_type != SKETCH_TYPE) {
+    if (preamble_longs == 0 && serial_version == 0 && sketch_type == 0) return 
deserialize_compat(ptr, end_ptr - ptr, allocator);
     throw std::invalid_argument("sketch type mismatch: expected " + 
std::to_string(SKETCH_TYPE) + ", actual " + std::to_string(sketch_type));
   }
   if (serial_version != SERIAL_VERSION) {
@@ -426,6 +428,111 @@ tdigest<T, A> tdigest<T, A>::deserialize(const void* bytes, size_t size, const A
   return tdigest(reverse_merge, k, min, max, std::move(centroids), 
total_weight, allocator);
 }
 
+// Deserializes a t-digest written in the format of the reference implementation.
+// deserialize(std::istream&) calls this after the first three bytes of the input
+// turned out to be zero; the next byte selects the layout. All multi-byte fields
+// are big-endian:
+//   COMPAT_DOUBLE: min, max, compression (doubles), a 32-bit centroid count,
+//                  then one (weight, mean) pair of doubles per centroid
+//   COMPAT_FLOAT (asSmallBytes()): min, max (doubles), compression (float),
+//                  4 unused bytes, a 16-bit centroid count, then one
+//                  (weight, mean) pair of floats per centroid
+// Throws std::invalid_argument for an unknown layout byte and std::runtime_error
+// if the stream fails while reading.
+template<typename T, typename A>
+tdigest<T, A> tdigest<T, A>::deserialize_compat(std::istream& is, const A& allocator) {
+  const auto type = read<uint8_t>(is);
+  if (!is.good()) throw std::runtime_error("error reading from std::istream");
+  if (type != COMPAT_DOUBLE && type != COMPAT_FLOAT) {
+    throw std::invalid_argument("unexpected sketch preamble: 0 0 0 " + std::to_string(type));
+  }
+  if (type == COMPAT_DOUBLE) {
+    const auto min = read_big_endian<double>(is);
+    const auto max = read_big_endian<double>(is);
+    const auto k_double = read_big_endian<double>(is);
+    const auto num_centroids = read_big_endian<uint32_t>(is);
+    // validate before converting: casting garbage doubles to integers is undefined
+    if (!is.good()) throw std::runtime_error("error reading from std::istream");
+    const auto k = static_cast<uint16_t>(k_double);
+    // Grow as centroids are actually read rather than pre-allocating: the count
+    // comes from the input, and a corrupt or truncated stream must not trigger
+    // a huge allocation.
+    vector_centroid centroids(allocator);
+    uint64_t total_weight = 0;
+    for (uint32_t i = 0; i < num_centroids; ++i) {
+      const auto weight = read_big_endian<double>(is);
+      const auto mean = read_big_endian<double>(is);
+      if (!is.good()) throw std::runtime_error("error reading from std::istream");
+      const auto w = static_cast<uint64_t>(weight);
+      centroids.push_back(centroid(mean, w));
+      total_weight += w;
+    }
+    return tdigest(false, k, min, max, std::move(centroids), total_weight, allocator);
+  }
+  // compatibility with asSmallBytes()
+  const auto min = read_big_endian<double>(is); // reference implementation uses doubles for min and max
+  const auto max = read_big_endian<double>(is);
+  const auto k_float = read_big_endian<float>(is);
+  read<uint32_t>(is); // unused here (presumably two 16-bit buffer sizes in the reference format)
+  const auto num_centroids = read_big_endian<uint16_t>(is);
+  if (!is.good()) throw std::runtime_error("error reading from std::istream");
+  const auto k = static_cast<uint16_t>(k_float);
+  vector_centroid centroids(allocator);
+  uint64_t total_weight = 0;
+  for (uint16_t i = 0; i < num_centroids; ++i) {
+    const auto weight = read_big_endian<float>(is);
+    const auto mean = read_big_endian<float>(is);
+    if (!is.good()) throw std::runtime_error("error reading from std::istream");
+    const auto w = static_cast<uint64_t>(weight);
+    centroids.push_back(centroid(mean, w));
+    total_weight += w;
+  }
+  return tdigest(false, k, min, max, std::move(centroids), total_weight, allocator);
+}
+
+// Byte-buffer counterpart of deserialize_compat(std::istream&): decodes a t-digest
+// in the format of the reference implementation. When called from
+// deserialize(const void*, size_t), `bytes` points just past the three leading
+// zero bytes and `size` is the number of bytes remaining. The first byte selects
+// the layout (COMPAT_DOUBLE or COMPAT_FLOAT); all multi-byte fields are
+// big-endian. Throws std::invalid_argument for an unknown layout byte;
+// insufficient input is rejected by ensure_minimum_memory.
+template<typename T, typename A>
+tdigest<T, A> tdigest<T, A>::deserialize_compat(const void* bytes, size_t size, const A& allocator) {
+  ensure_minimum_memory(size, 1); // the layout byte itself
+  const char* ptr = static_cast<const char*>(bytes);
+  const char* end_ptr = ptr + size;
+  // read as unsigned: plain char may be signed, which would make the error
+  // message disagree with the stream version for bytes >= 0x80
+  const uint8_t type = static_cast<uint8_t>(*ptr++);
+  if (type != COMPAT_DOUBLE && type != COMPAT_FLOAT) {
+    throw std::invalid_argument("unexpected sketch preamble: 0 0 0 " + std::to_string(type));
+  }
+  if (type == COMPAT_DOUBLE) {
+    // min, max, compression (doubles) and a 32-bit centroid count
+    ensure_minimum_memory(end_ptr - ptr, sizeof(double) * 3 + sizeof(uint32_t));
+    double min;
+    ptr += copy_from_mem(ptr, min);
+    min = byteswap(min);
+    double max;
+    ptr += copy_from_mem(ptr, max);
+    max = byteswap(max);
+    double k_double;
+    ptr += copy_from_mem(ptr, k_double);
+    const uint16_t k = static_cast<uint16_t>(byteswap(k_double));
+    uint32_t num_centroids;
+    ptr += copy_from_mem(ptr, num_centroids);
+    num_centroids = byteswap(num_centroids);
+    // one (weight, mean) pair of doubles per centroid
+    ensure_minimum_memory(end_ptr - ptr, sizeof(double) * num_centroids * 2);
+    vector_centroid centroids(num_centroids, centroid(0, 0), allocator);
+    uint64_t total_weight = 0;
+    for (auto& c: centroids) {
+      double weight;
+      ptr += copy_from_mem(ptr, weight);
+      double mean;
+      ptr += copy_from_mem(ptr, mean);
+      const uint64_t w = static_cast<uint64_t>(byteswap(weight));
+      c = centroid(byteswap(mean), w);
+      total_weight += w;
+    }
+    return tdigest(false, k, min, max, std::move(centroids), total_weight, allocator);
+  }
+  // compatibility with asSmallBytes(): min, max (doubles), compression (float),
+  // 4 unused bytes and a 16-bit centroid count
+  ensure_minimum_memory(end_ptr - ptr, sizeof(double) * 2 + sizeof(float) + sizeof(uint16_t) * 3);
+  double min;
+  ptr += copy_from_mem(ptr, min);
+  min = byteswap(min);
+  double max;
+  ptr += copy_from_mem(ptr, max);
+  max = byteswap(max);
+  float k_float;
+  ptr += copy_from_mem(ptr, k_float);
+  const uint16_t k = static_cast<uint16_t>(byteswap(k_float));
+  ptr += sizeof(uint32_t); // unused
+  uint16_t num_centroids;
+  ptr += copy_from_mem(ptr, num_centroids);
+  num_centroids = byteswap(num_centroids);
+  // one (weight, mean) pair of floats per centroid
+  ensure_minimum_memory(end_ptr - ptr, sizeof(float) * num_centroids * 2);
+  vector_centroid centroids(num_centroids, centroid(0, 0), allocator);
+  uint64_t total_weight = 0;
+  for (auto& c: centroids) {
+    float weight;
+    ptr += copy_from_mem(ptr, weight);
+    float mean;
+    ptr += copy_from_mem(ptr, mean);
+    const uint64_t w = static_cast<uint64_t>(byteswap(weight));
+    c = centroid(byteswap(mean), w);
+    total_weight += w;
+  }
+  return tdigest(false, k, min, max, std::move(centroids), total_weight, allocator);
+}
+
 template<typename T, typename A>
 tdigest<T, A>::tdigest(bool reverse_merge, uint16_t k, T min, T max, 
vector_centroid&& centroids, uint64_t total_weight, const A& allocator):
 allocator_(allocator),
diff --git a/tdigest/test/tdigest_test.cpp b/tdigest/test/tdigest_test.cpp
index b1627d5..cfed630 100644
--- a/tdigest/test/tdigest_test.cpp
+++ b/tdigest/test/tdigest_test.cpp
@@ -19,6 +19,7 @@
 
 #include <catch2/catch.hpp>
 #include <iostream>
+#include <fstream>
 
 #include "tdigest.hpp"
 
@@ -229,4 +230,70 @@ TEST_CASE("serialize deserialize steam and bytes equivalence", "[tdigest]") {
   REQUIRE(deserialized_td1.get_quantile(0.5) == 
deserialized_td2.get_quantile(0.5));
 }
 
+// Expectations shared by the tdigest_ref_k100_n10000_*.sk inputs. Judging by the
+// file names and the checked values, these digests presumably hold the values
+// 0..n-1 with n = 10000.
+template<typename T>
+static void check_ref_k100_n10000(const tdigest<T>& td) {
+  const size_t n = 10000;
+  REQUIRE(td.get_total_weight() == n);
+  REQUIRE(td.get_min_value() == 0);
+  REQUIRE(td.get_max_value() == n - 1);
+  REQUIRE(td.get_rank(0) == Approx(0).margin(0.0001));
+  REQUIRE(td.get_rank(n / 4) == Approx(0.25).margin(0.0001));
+  REQUIRE(td.get_rank(n / 2) == Approx(0.5).margin(0.0001));
+  REQUIRE(td.get_rank(n * 3 / 4) == Approx(0.75).margin(0.0001));
+  REQUIRE(td.get_rank(n) == 1);
+}
+
+// Opens a binary test input; failure to open or read throws via the exception mask.
+static void open_test_input(std::ifstream& is, const std::string& file_name) {
+  is.exceptions(std::ios::failbit | std::ios::badbit);
+  is.open(std::string(TEST_BINARY_INPUT_PATH) + file_name, std::ios::binary);
+}
+
+// Reads a whole binary test input into memory.
+static std::vector<char> read_test_input(const std::string& file_name) {
+  std::ifstream is;
+  open_test_input(is, file_name);
+  return std::vector<char>((std::istreambuf_iterator<char>(is)), std::istreambuf_iterator<char>());
+}
+
+TEST_CASE("deserialize from reference implementation stream double", "[tdigest]") {
+  std::ifstream is;
+  open_test_input(is, "tdigest_ref_k100_n10000_double.sk");
+  check_ref_k100_n10000(tdigest<double>::deserialize(is));
+}
+
+TEST_CASE("deserialize from reference implementation stream float", "[tdigest]") {
+  std::ifstream is;
+  open_test_input(is, "tdigest_ref_k100_n10000_float.sk");
+  check_ref_k100_n10000(tdigest<float>::deserialize(is));
+}
+
+TEST_CASE("deserialize from reference implementation bytes double", "[tdigest]") {
+  const auto bytes = read_test_input("tdigest_ref_k100_n10000_double.sk");
+  check_ref_k100_n10000(tdigest<double>::deserialize(bytes.data(), bytes.size()));
+}
+
+TEST_CASE("deserialize from reference implementation bytes float", "[tdigest]") {
+  const auto bytes = read_test_input("tdigest_ref_k100_n10000_float.sk");
+  // exercise the float layout with T = float (matching the stream float test)
+  // as well as with T = double
+  check_ref_k100_n10000(tdigest<float>::deserialize(bytes.data(), bytes.size()));
+  check_ref_k100_n10000(tdigest<double>::deserialize(bytes.data(), bytes.size()));
+}
+
 } /* namespace datasketches */


---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]

Reply via email to