tisonkun commented on code in PR #1:
URL: https://github.com/apache/datasketches-rust/pull/1#discussion_r2617093566


##########
src/hll/sketch.rs:
##########
@@ -0,0 +1,347 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements.  See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership.  The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License.  You may obtain a copy of the License at
+//
+//   http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied.  See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+//! HyperLogLog sketch implementation
+//!
+//! This module provides the main [`HllSketch`] struct, which is the primary interface
+//! for creating and using HLL sketches for cardinality estimation.
+
+use std::hash::Hash;
+
+use crate::error::{SerdeError, SerdeResult};
+use crate::hll::array4::Array4;
+use crate::hll::array6::Array6;
+use crate::hll::array8::Array8;
+use crate::hll::container::Container;
+use crate::hll::hash_set::HashSet;
+use crate::hll::list::List;
+use crate::hll::serialization::*;
+use crate::hll::{HllType, RESIZE_DENOM, RESIZE_NUMER, coupon};
+
+/// Internal representation recorded in the mode byte of a serialized sketch.
+#[derive(Clone, Copy, Debug, Eq, PartialEq)]
+enum CurMode {
+    /// Coupons are kept in a list.
+    List = 0,
+    /// Coupons are kept in a hash set.
+    Set = 1,
+    /// Registers are kept in an HLL array (Array4, Array6 or Array8).
+    Hll = 2,
+}
+
+/// A HyperLogLog sketch.
+///
+/// See the [module level documentation](self) for more.
+#[derive(Clone, Debug, PartialEq)]
+pub struct HllSketch {
+    /// Log2 of the number of buckets (K).
+    lg_config_k: u8,
+    /// The representation currently holding the sketch state.
+    mode: Mode,
+}
+
+/// Storage backing an [`HllSketch`].
+///
+/// The `List` and `Set` variants carry the target [`HllType`], which is passed
+/// along when the sketch is promoted to a larger representation.
+#[derive(Clone, Debug, PartialEq)]
+enum Mode {
+    /// Coupons stored in a list.
+    List { list: List, hll_type: HllType },
+    /// Coupons stored in a hash set.
+    Set { set: HashSet, hll_type: HllType },
+    /// HLL array with packed 4-bit slots (plus an auxiliary map for overflow).
+    Array4(Array4),
+    /// HLL array with packed 6-bit slots.
+    Array6(Array6),
+    /// HLL array with one byte per slot.
+    Array8(Array8),
+}
+
+impl HllSketch {
+    /// Create a new HLL sketch
+    ///
+    /// # Arguments
+    ///
+    /// * `lg_config_k` - Log2 of the number of buckets (K). Must be in [4, 
21].
+    ///   - lg_k=4: 16 buckets, ~26% relative error
+    ///   - lg_k=12: 4096 buckets, ~1.6% relative error (common choice)
+    ///   - lg_k=21: 2M buckets, ~0.4% relative error
+    /// * `hll_type` - Target HLL array type (Hll4, Hll6, or Hll8)
+    ///
+    /// # Panics
+    ///
+    /// Panics if lg_config_k is not in [4, 21]
+    pub fn new(lg_config_k: u8, hll_type: HllType) -> Self {
+        assert!(
+            lg_config_k > 4 && lg_config_k < 21,
+            "lg_config_k must be in [4, 21]"
+        );
+
+        let list = List::default();
+
+        Self {
+            lg_config_k,
+            mode: Mode::List { list, hll_type },
+        }
+    }
+
+    /// Get the configured lg_config_k
+    pub fn lg_config_k(&self) -> u8 {
+        self.lg_config_k
+    }
+
+    /// Update the sketch with a value
+    ///
+    /// This accepts any type that implements `Hash`. The value is hashed
+    /// and converted to a coupon, which is then inserted into the sketch.
+    pub fn update<T: Hash>(&mut self, value: T) {
+        let coupon = coupon(value);
+        self.update_with_coupon(coupon);
+    }
+
+    /// Update the sketch with a raw coupon value
+    ///
+    /// This is useful when you've already computed the coupon externally,
+    /// or when deserializing and replaying coupons.
+    fn update_with_coupon(&mut self, coupon: u32) {
+        match &mut self.mode {
+            Mode::List { list, hll_type } => {
+                list.update(coupon);
+                let should_promote = list.container().is_full();
+                if should_promote {
+                    self.mode = if self.lg_config_k < 8 {
+                        promote_container_to_array(list.container(), 
*hll_type, self.lg_config_k)
+                    } else {
+                        promote_container_to_set(list.container(), *hll_type)
+                    }
+                }
+            }
+            Mode::Set { set, hll_type } => {
+                set.update(coupon);
+                let should_promote = RESIZE_DENOM as usize * 
set.container().len()
+                    > RESIZE_NUMER as usize * set.container().capacity();
+                if should_promote {
+                    self.mode = if set.container().lg_size() == 
self.lg_config_k as usize - 3 {
+                        promote_container_to_array(set.container(), *hll_type, 
self.lg_config_k)
+                    } else {
+                        grow_set(set, *hll_type)
+                    }
+                }
+            }
+            Mode::Array4(arr) => arr.update(coupon),
+            Mode::Array6(arr) => arr.update(coupon),
+            Mode::Array8(arr) => arr.update(coupon),
+        }
+    }
+
+    /// Get the current cardinality estimate
+    pub fn estimate(&self) -> f64 {
+        match &self.mode {
+            Mode::List { list, .. } => list.container().estimate(),
+            Mode::Set { set, .. } => set.container().estimate(),
+            Mode::Array4(arr) => arr.estimate(),
+            Mode::Array6(arr) => arr.estimate(),
+            Mode::Array8(arr) => arr.estimate(),
+        }
+    }
+
+    /// Get upper bound for cardinality estimate
+    ///
+    /// Returns the upper confidence bound for the cardinality estimate based 
on
+    /// the number of standard deviations requested.
+    pub fn upper_bound(&self, num_std_dev: u8) -> f64 {
+        match &self.mode {
+            Mode::List { list, .. } => 
list.container().upper_bound(num_std_dev),
+            Mode::Set { set, .. } => set.container().upper_bound(num_std_dev),
+            Mode::Array4(arr) => arr.upper_bound(num_std_dev),
+            Mode::Array6(arr) => arr.upper_bound(num_std_dev),
+            Mode::Array8(arr) => arr.upper_bound(num_std_dev),
+        }
+    }
+
+    /// Get lower bound for cardinality estimate
+    ///
+    /// Returns the lower confidence bound for the cardinality estimate based 
on
+    /// the number of standard deviations requested.
+    pub fn lower_bound(&self, num_std_dev: u8) -> f64 {
+        match &self.mode {
+            Mode::List { list, .. } => 
list.container().lower_bound(num_std_dev),
+            Mode::Set { set, .. } => set.container().lower_bound(num_std_dev),
+            Mode::Array4(arr) => arr.lower_bound(num_std_dev),
+            Mode::Array6(arr) => arr.lower_bound(num_std_dev),
+            Mode::Array8(arr) => arr.lower_bound(num_std_dev),
+        }
+    }
+
+    /// Deserializes an HLL sketch from bytes
+    pub fn deserialize(bytes: &[u8]) -> SerdeResult<HllSketch> {
+        if bytes.len() < 8 {
+            return Err(SerdeError::InsufficientData(
+                "sketch data too short (< 8 bytes)".to_string(),
+            ));
+        }
+
+        // Read and validate preamble
+        let preamble_ints = bytes[PREAMBLE_INTS_BYTE];
+        let ser_ver = bytes[SER_VER_BYTE];
+        let family_id = bytes[FAMILY_BYTE];
+        let lg_config_k = bytes[LG_K_BYTE];
+        let flags = bytes[FLAGS_BYTE];
+        let mode_byte = bytes[MODE_BYTE];
+
+        // Verify family ID
+        if family_id != HLL_FAMILY_ID {
+            return Err(SerdeError::InvalidFamily(format!(
+                "expected {} (HLL), got {}",
+                HLL_FAMILY_ID, family_id
+            )));
+        }
+
+        // Verify serialization version
+        if ser_ver != SER_VER {
+            return Err(SerdeError::UnsupportedVersion(format!(
+                "expected {}, got {}",
+                SER_VER, ser_ver
+            )));
+        }
+
+        // Verify lg_k range (4-21 are valid)
+        if !(4..=21).contains(&lg_config_k) {
+            return Err(SerdeError::InvalidParameter(format!(
+                "lg_k must be in [4; 21], got {}",
+                lg_config_k
+            )));
+        }
+
+        // Extract mode and type
+        let cur_mode = match extract_cur_mode(mode_byte) {
+            CUR_MODE_LIST => CurMode::List,
+            CUR_MODE_SET => CurMode::Set,
+            CUR_MODE_HLL => CurMode::Hll,
+            mode => return Err(SerdeError::MalformedData(format!("invalid 
mode: {}", mode))),
+        };
+        let hll_type = match extract_tgt_hll_type(mode_byte) {
+            TGT_HLL4 => HllType::Hll4,
+            TGT_HLL6 => HllType::Hll6,
+            TGT_HLL8 => HllType::Hll8,
+            hll_type => {
+                return Err(SerdeError::MalformedData(format!(
+                    "invalid HLL type: {}",
+                    hll_type
+                )));
+            }
+        };
+        let empty = (flags & EMPTY_FLAG_MASK) != 0;
+        let compact = (flags & COMPACT_FLAG_MASK) != 0;
+        let ooo = (flags & OUT_OF_ORDER_FLAG_MASK) != 0;
+
+        // Deserialize based on mode
+        let mode =
+            match cur_mode {
+                CurMode::List => {
+                    if preamble_ints != LIST_PREINTS {
+                        return Err(SerdeError::MalformedData(format!(
+                            "LIST mode preamble: expected {}, got {}",
+                            LIST_PREINTS, preamble_ints
+                        )));
+                    }
+
+                    let list = List::deserialize(bytes, empty, compact)?;
+                    Mode::List { list, hll_type }
+                }
+                CurMode::Set => {
+                    if preamble_ints != HASH_SET_PREINTS {
+                        return Err(SerdeError::MalformedData(format!(
+                            "SET mode preamble: expected {}, got {}",
+                            HASH_SET_PREINTS, preamble_ints
+                        )));
+                    }
+
+                    let set = HashSet::deserialize(bytes, compact)?;
+                    Mode::Set { set, hll_type }
+                }
+                CurMode::Hll => {
+                    if preamble_ints != HLL_PREINTS {
+                        return Err(SerdeError::MalformedData(format!(
+                            "HLL mode preamble: expected {}, got {}",
+                            HLL_PREINTS, preamble_ints
+                        )));
+                    }
+
+                    match hll_type {
+                        HllType::Hll4 => Array4::deserialize(bytes, 
lg_config_k, compact, ooo)
+                            .map(Mode::Array4)?,
+                        HllType::Hll6 => Array6::deserialize(bytes, 
lg_config_k, compact, ooo)
+                            .map(Mode::Array6)?,
+                        HllType::Hll8 => Array8::deserialize(bytes, 
lg_config_k, compact, ooo)
+                            .map(Mode::Array8)?,
+                    }
+                }
+            };
+
+        Ok(HllSketch { lg_config_k, mode })
+    }
+
+    /// Serializes the HLL sketch to bytes
+    pub fn serialize(&self) -> SerdeResult<Vec<u8>> {
+        match &self.mode {
+            Mode::List { list, hll_type } => list.serialize(self.lg_config_k, 
*hll_type),
+            Mode::Set { set, hll_type } => set.serialize(self.lg_config_k, 
*hll_type),
+            Mode::Array4(arr) => arr.serialize(self.lg_config_k),
+            Mode::Array6(arr) => arr.serialize(self.lg_config_k),
+            Mode::Array8(arr) => arr.serialize(self.lg_config_k),
+        }
+    }
+}

Review Comment:
   If we are dumping the HllSketch to an always-growable `Vec<u8>`, I think serialization is infallible here:
   
   ```diff
   diff --git a/src/hll/array4.rs b/src/hll/array4.rs
   index 0ab2fce..dbcda53 100644
   --- a/src/hll/array4.rs
   +++ b/src/hll/array4.rs
   @@ -21,6 +21,7 @@
    //! When values exceed 4 bits after cur_min offset, they're stored in an 
auxiliary hash map.
    
    use super::aux_map::AuxMap;
   +use crate::error::SerdeResult;
    use crate::hll::estimator::HipEstimator;
    use crate::hll::{get_slot, get_value};
    
   @@ -249,7 +250,7 @@ impl Array4 {
            lg_config_k: u8,
            compact: bool,
            ooo: bool,
   -    ) -> crate::error::SerdeResult<Self> {
   +    ) -> SerdeResult<Self> {
            use crate::error::SerdeError;
            use crate::hll::serialization::*;
            use crate::hll::{get_slot, get_value};
   @@ -325,7 +326,7 @@ impl Array4 {
        /// Serialize Array4 to bytes
        ///
        /// Produces full HLL preamble (40 bytes) followed by packed 4-bit data 
and optional aux map.
   -    pub fn serialize(&self, lg_config_k: u8) -> 
crate::error::SerdeResult<Vec<u8>> {
   +    pub fn serialize(&self, lg_config_k: u8) -> Vec<u8> {
            use crate::hll::pack_coupon;
            use crate::hll::serialization::*;
    
   @@ -384,7 +385,7 @@ impl Array4 {
                write_u32_le(&mut bytes, offset, coupon);
            }
    
   -        Ok(bytes)
   +        bytes
        }
    }
    
   diff --git a/src/hll/array6.rs b/src/hll/array6.rs
   index 27a5997..4b5d8a4 100644
   --- a/src/hll/array6.rs
   +++ b/src/hll/array6.rs
   @@ -21,6 +21,7 @@
    //! This is sufficient for most HLL use cases without needing exception 
handling or
    //! cur_min optimization like Array4.
    
   +use crate::error::SerdeResult;
    use crate::hll::estimator::HipEstimator;
    use crate::hll::{get_slot, get_value};
    
   @@ -149,7 +150,7 @@ impl Array6 {
            lg_config_k: u8,
            compact: bool,
            ooo: bool,
   -    ) -> crate::error::SerdeResult<Self> {
   +    ) -> SerdeResult<Self> {
            use crate::error::SerdeError;
            use crate::hll::serialization::*;
    
   @@ -201,7 +202,7 @@ impl Array6 {
        /// Serialize Array6 to bytes
        ///
        /// Produces full HLL preamble (40 bytes) followed by packed 6-bit data.
   -    pub fn serialize(&self, lg_config_k: u8) -> 
crate::error::SerdeResult<Vec<u8>> {
   +    pub fn serialize(&self, lg_config_k: u8) -> Vec<u8> {
            use crate::hll::serialization::*;
    
            let k = 1 << lg_config_k;
   @@ -243,7 +244,7 @@ impl Array6 {
            // Write packed byte array
            bytes[HLL_BYTE_ARR_START..].copy_from_slice(&self.bytes);
    
   -        Ok(bytes)
   +        bytes
        }
    }
    
   diff --git a/src/hll/array8.rs b/src/hll/array8.rs
   index 6d187d5..4912dc8 100644
   --- a/src/hll/array8.rs
   +++ b/src/hll/array8.rs
   @@ -20,6 +20,7 @@
    //! Array8 is the simplest HLL array implementation, storing one byte per 
slot.
    //! This provides the maximum value range (0-255) with no bit-packing 
complexity.
    
   +use crate::error::SerdeResult;
    use crate::hll::estimator::HipEstimator;
    use crate::hll::{get_slot, get_value};
    
   @@ -119,7 +120,7 @@ impl Array8 {
            lg_config_k: u8,
            compact: bool,
            ooo: bool,
   -    ) -> crate::error::SerdeResult<Self> {
   +    ) -> SerdeResult<Self> {
            use crate::error::SerdeError;
            use crate::hll::serialization::*;
    
   @@ -170,7 +171,7 @@ impl Array8 {
        /// Serialize Array8 to bytes
        ///
        /// Produces full HLL preamble (40 bytes) followed by k bytes of data.
   -    pub fn serialize(&self, lg_config_k: u8) -> 
crate::error::SerdeResult<Vec<u8>> {
   +    pub fn serialize(&self, lg_config_k: u8) -> Vec<u8> {
            use crate::hll::serialization::*;
    
            let k = 1 << lg_config_k;
   @@ -211,7 +212,7 @@ impl Array8 {
            // Write byte array
            bytes[HLL_BYTE_ARR_START..].copy_from_slice(&self.bytes);
    
   -        Ok(bytes)
   +        bytes
        }
    }
    
   diff --git a/src/hll/hash_set.rs b/src/hll/hash_set.rs
   index 9859730..22a583f 100644
   --- a/src/hll/hash_set.rs
   +++ b/src/hll/hash_set.rs
   @@ -136,7 +136,7 @@ impl HashSet {
        }
    
        /// Serialize a HashSet to bytes
   -    pub fn serialize(&self, lg_config_k: u8, hll_type: HllType) -> 
SerdeResult<Vec<u8>> {
   +    pub fn serialize(&self, lg_config_k: u8, hll_type: HllType) -> Vec<u8> {
            let compact = true; // Always use compact format
            let coupon_count = self.container.len();
            let lg_arr = self.container.lg_size();
   @@ -191,6 +191,6 @@ impl HashSet {
                }
            }
    
   -        Ok(bytes)
   +        bytes
        }
    }
   diff --git a/src/hll/list.rs b/src/hll/list.rs
   index 5810750..305e38d 100644
   --- a/src/hll/list.rs
   +++ b/src/hll/list.rs
   @@ -98,7 +98,7 @@ impl List {
        }
    
        /// Serialize a List to bytes
   -    pub fn serialize(&self, lg_config_k: u8, hll_type: HllType) -> 
SerdeResult<Vec<u8>> {
   +    pub fn serialize(&self, lg_config_k: u8, hll_type: HllType) -> Vec<u8> {
            let compact = true; // Always use compact format
            let empty = self.container.len() == 0;
            let coupon_count = self.container.len();
   @@ -149,6 +149,6 @@ impl List {
                }
            }
    
   -        Ok(bytes)
   +        bytes
        }
    }
   diff --git a/src/hll/sketch.rs b/src/hll/sketch.rs
   index 51ead14..dfae923 100644
   --- a/src/hll/sketch.rs
   +++ b/src/hll/sketch.rs
   @@ -284,7 +284,7 @@ impl HllSketch {
        }
    
        /// Serializes the HLL sketch to bytes
   -    pub fn serialize(&self) -> SerdeResult<Vec<u8>> {
   +    pub fn serialize(&self) -> Vec<u8> {
            match &self.mode {
                Mode::List { list, hll_type } => 
list.serialize(self.lg_config_k, *hll_type),
                Mode::Set { set, hll_type } => set.serialize(self.lg_config_k, 
*hll_type),
   ```



-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

To unsubscribe, e-mail: [email protected]

For queries about this service, please contact Infrastructure at:
[email protected]


---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]

Reply via email to