This is an automated email from the ASF dual-hosted git repository.

alsay pushed a commit to branch tdigest_pmf_and_cdf
in repository https://gitbox.apache.org/repos/asf/datasketches-cpp.git

commit f0d4cb7c7113ca461801ca5d52934a61872f9ab6
Author: AlexanderSaydakov <[email protected]>
AuthorDate: Tue Oct 22 18:07:47 2024 -0700

    implemented get_PMF() and get_CDF()
---
 tdigest/include/tdigest.hpp      | 50 ++++++++++++++++++++++++++++++++++++++++
 tdigest/include/tdigest_impl.hpp | 36 +++++++++++++++++++++++++++++
 tdigest/test/tdigest_test.cpp    | 15 +++++++++---
 3 files changed, 98 insertions(+), 3 deletions(-)

diff --git a/tdigest/include/tdigest.hpp b/tdigest/include/tdigest.hpp
index bb4f737..d33084e 100644
--- a/tdigest/include/tdigest.hpp
+++ b/tdigest/include/tdigest.hpp
@@ -89,6 +89,8 @@ public:
   using vector_t = std::vector<T, Allocator>;
   using vector_centroid = std::vector<centroid, typename std::allocator_traits<Allocator>::template rebind_alloc<centroid>>;
   using vector_bytes = std::vector<uint8_t, typename std::allocator_traits<Allocator>::template rebind_alloc<uint8_t>>;
+  // result type of get_PMF() and get_CDF()
+  using vector_double = std::vector<double, typename std::allocator_traits<Allocator>::template rebind_alloc<double>>;
 
   struct centroid_cmp {
     centroid_cmp() {}
@@ -142,8 +143,17 @@ public:
    */
   uint64_t get_total_weight() const;
 
+  /**
+   * Returns a copy of the allocator of this t-Digest.
+   * @return allocator (returned by value)
+   */
+  Allocator get_allocator() const;
+
   /**
    * Compute approximate normalized rank of the given value.
+   *
+   * <p>If the sketch is empty this throws std::runtime_error.
+   *
    * @param value to be ranked
    * @return normalized rank (from 0 to 1 inclusive)
    */
@@ -151,11 +161,55 @@ public:
 
   /**
    * Compute approximate quantile value corresponding to the given normalized rank
+   *
+   * <p>If the sketch is empty this throws std::runtime_error.
+   *
    * @param rank normalized rank (from 0 to 1 inclusive)
    * @return quantile value corresponding to the given rank
    */
   T get_quantile(double rank) const;
 
+  /**
+   * Returns an approximation to the Probability Mass Function (PMF) of the input stream
+   * given a set of split points.
+   *
+   * <p>If the sketch is empty this throws std::runtime_error.
+   *
+   * <p>If any split point is NaN, or the split points are not strictly increasing,
+   * this throws std::invalid_argument.
+   *
+   * @param split_points an array of <i>m</i> unique, monotonically increasing values
+   * that divide the input domain into <i>m+1</i> consecutive disjoint intervals (bins).
+   *
+   * @param size the number of split points in the array
+   *
+   * @return an array of m+1 doubles each of which is an approximation
+   * to the fraction of the input stream values (the mass) that fall into one of those intervals.
+   */
+  vector_double get_PMF(const T* split_points, uint32_t size) const;
+
+  /**
+   * Returns an approximation to the Cumulative Distribution Function (CDF), which is the
+   * cumulative analog of the PMF, of the input stream given a set of split points.
+   *
+   * <p>If the sketch is empty this throws std::runtime_error.
+   *
+   * <p>If any split point is NaN, or the split points are not strictly increasing,
+   * this throws std::invalid_argument.
+   *
+   * @param split_points an array of <i>m</i> unique, monotonically increasing values
+   * that divide the input domain into <i>m+1</i> consecutive disjoint intervals.
+   *
+   * @param size the number of split points in the array
+   *
+   * @return an array of m+1 doubles, which are a consecutive approximation to the CDF
+   * of the input stream given the split_points. The value at array position j of the returned
+   * CDF array is the sum of the returned values in positions 0 through j of the returned PMF
+   * array. This can be viewed as an array of ranks of the given split points plus one more value
+   * that is always 1.
+   */
+  vector_double get_CDF(const T* split_points, uint32_t size) const;
+
   /**
    * @return parameter k (compression) that was used to configure this t-Digest
    */
@@ -245,6 +293,9 @@ private:
   // for compatibility with format of the reference implementation
   static tdigest deserialize_compat(std::istream& is, const Allocator& allocator = Allocator());
   static tdigest deserialize_compat(const void* bytes, size_t size, const Allocator& allocator = Allocator());
+
+  // throws std::invalid_argument if any value is NaN or the values are not strictly increasing
+  static inline void check_split_points(const T* values, uint32_t size);
 };
 
 } /* namespace datasketches */
diff --git a/tdigest/include/tdigest_impl.hpp b/tdigest/include/tdigest_impl.hpp
index 165bda6..6e3ae1a 100644
--- a/tdigest/include/tdigest_impl.hpp
+++ b/tdigest/include/tdigest_impl.hpp
@@ -85,6 +85,12 @@ uint64_t tdigest<T, A>::get_total_weight() const {
   return centroids_weight_ + buffer_.size();
 }
 
+// the allocator is taken from buffer_ and returned by value
+template<typename T, typename A>
+A tdigest<T, A>::get_allocator() const {
+  return buffer_.get_allocator();
+}
+
 template<typename T, typename A>
 double tdigest<T, A>::get_rank(T value) const {
   if (is_empty()) throw std::runtime_error("operation is undefined for an empty sketch");
@@ -191,6 +196,33 @@ T tdigest<T, A>::get_quantile(double rank) const {
   return weighted_average(centroids_.back().get_weight(), w1, max_, w2);
 }
 
+// PMF buckets are differences of adjacent CDF values: bucket 0 keeps the rank of
+// split_points[0] and bucket i (i >= 1) becomes CDF[i] - CDF[i - 1].
+// Emptiness and split point validation happen inside get_CDF().
+template<typename T, typename A>
+auto tdigest<T, A>::get_PMF(const T* split_points, uint32_t size) const -> vector_double {
+  auto buckets = get_CDF(split_points, size);
+  // walk backwards so each subtraction still reads the cumulative value on its left
+  for (uint32_t i = size; i > 0; --i) {
+    buckets[i] -= buckets[i - 1];
+  }
+  return buckets;
+}
+
+// Ranks of the split points followed by a final 1, i.e. size + 1 values.
+// The sketch is checked for emptiness up front so that an empty sketch throws even
+// when no split points are given (size == 0); otherwise it would return {1}.
+template<typename T, typename A>
+auto tdigest<T, A>::get_CDF(const T* split_points, uint32_t size) const -> vector_double {
+  if (is_empty()) throw std::runtime_error("operation is undefined for an empty sketch");
+  check_split_points(split_points, size);
+  vector_double ranks(get_allocator());
+  ranks.reserve(size + 1);
+  for (uint32_t i = 0; i < size; ++i) ranks.push_back(get_rank(split_points[i]));
+  ranks.push_back(1);
+  return ranks;
+}
+
 template<typename T, typename A>
 uint16_t tdigest<T, A>::get_k() const {
   return k_;
@@ -591,6 +615,25 @@ buffer_(std::move(buffer))
   buffer_.reserve(centroids_capacity_ * BUFFER_MULTIPLIER);
 }
 
+// Validates split points for get_PMF() and get_CDF(): throws std::invalid_argument
+// unless every value is non-NaN and the values are strictly increasing.
+// All values are screened for NaN before the ordering is checked, because a NaN
+// operand makes the ordering comparison fail and would be misreported as disorder.
+template<typename T, typename A>
+void tdigest<T, A>::check_split_points(const T* values, uint32_t size) {
+  for (uint32_t i = 0; i < size; ++i) {
+    if (std::isnan(values[i])) {
+      throw std::invalid_argument("Values must not be NaN");
+    }
+  }
+  // comparing each value with its predecessor avoids computing size - 1 on an unsigned count
+  for (uint32_t i = 1; i < size; ++i) {
+    if (!(values[i - 1] < values[i])) {
+      throw std::invalid_argument("Values must be unique and monotonically increasing");
+    }
+  }
+}
+
 } /* namespace datasketches */
 
 #endif // _TDIGEST_IMPL_HPP_
diff --git a/tdigest/test/tdigest_test.cpp b/tdigest/test/tdigest_test.cpp
index bf64dbb..fc3f5d1 100644
--- a/tdigest/test/tdigest_test.cpp
+++ b/tdigest/test/tdigest_test.cpp
@@ -35,6 +35,10 @@ TEST_CASE("empty", "[tdigest]") {
   REQUIRE_THROWS_AS(td.get_max_value(), std::runtime_error);
   REQUIRE_THROWS_AS(td.get_rank(0), std::runtime_error);
   REQUIRE_THROWS_AS(td.get_quantile(0.5), std::runtime_error);
+  // the split-point queries must reject an empty sketch as well
+  const double split_points[] {0.0};
+  REQUIRE_THROWS_AS(td.get_CDF(split_points, 1), std::runtime_error);
+  REQUIRE_THROWS_AS(td.get_PMF(split_points, 1), std::runtime_error);
 }
 
 TEST_CASE("one value", "[tdigest]") {
@@ -56,9 +59,6 @@ TEST_CASE("many values", "[tdigest]") {
   const size_t n = 10000;
   tdigest_double td;
   for (size_t i = 0; i < n; ++i) td.update(i);
-//  std::cout << td.to_string(true);
-//  td.compress();
-//  std::cout << td.to_string(true);
   REQUIRE_FALSE(td.is_empty());
   REQUIRE(td.get_total_weight() == n);
   REQUIRE(td.get_min_value() == 0);
@@ -73,6 +73,16 @@ TEST_CASE("many values", "[tdigest]") {
   REQUIRE(td.get_quantile(0.9) == Approx(n * 0.9).epsilon(0.01));
   REQUIRE(td.get_quantile(0.95) == Approx(n * 0.95).epsilon(0.01));
   REQUIRE(td.get_quantile(1) == n - 1);
+  // one split point at n / 2 divides the uniformly spread values roughly in half
+  const double split_points[] {n / 2.0};
+  const auto cdf = td.get_CDF(split_points, 1);
+  REQUIRE(cdf.size() == 2);
+  REQUIRE(cdf[0] == Approx(0.5).margin(0.0001));
+  REQUIRE(cdf[1] == 1);
+  const auto pmf = td.get_PMF(split_points, 1);
+  REQUIRE(pmf.size() == 2);
+  REQUIRE(pmf[0] == Approx(0.5).margin(0.0001));
+  REQUIRE(pmf[1] == Approx(0.5).margin(0.0001));
 }
 
 TEST_CASE("rank - two values", "[tdigest]") {


---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]

Reply via email to