This is an automated email from the ASF dual-hosted git repository.

leerho pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/datasketches-rust.git


The following commit(s) were added to refs/heads/main by this push:
     new f5143c3  feat: add a generic error type  (#47)
f5143c3 is described below

commit f5143c311ae169b9c0bead1e47bd58881526f83e
Author: tison <[email protected]>
AuthorDate: Tue Dec 30 09:53:15 2025 +0800

    feat: add a generic error type  (#47)
    
    * feat: add a generic error type
    
    Signed-off-by: tison <[email protected]>
    
    * apply
    
    Signed-off-by: tison <[email protected]>
    
    * convenience constructors
    
    Signed-off-by: tison <[email protected]>
    
    * more
    
    Signed-off-by: tison <[email protected]>
    
    * rename error kind
    
    Signed-off-by: tison <[email protected]>
    
    * no need source for now
    
    Signed-off-by: tison <[email protected]>
    
    ---------
    
    Signed-off-by: tison <[email protected]>
---
 Cargo.lock                            |   7 ++
 Cargo.toml                            |   1 +
 datasketches/Cargo.toml               |   1 +
 datasketches/src/countmin/sketch.rs   |  58 +++++-----
 datasketches/src/error.rs             | 167 ++++++++++++++++++++++++----
 datasketches/src/hll/array4.rs        |  10 +-
 datasketches/src/hll/array6.rs        |   8 +-
 datasketches/src/hll/array8.rs        |   8 +-
 datasketches/src/hll/hash_set.rs      |  10 +-
 datasketches/src/hll/list.rs          |   8 +-
 datasketches/src/hll/serialization.rs |   2 +-
 datasketches/src/hll/sketch.rs        |  42 +++-----
 datasketches/src/tdigest/sketch.rs    | 197 ++++++++++++++++------------------
 13 files changed, 309 insertions(+), 210 deletions(-)

diff --git a/Cargo.lock b/Cargo.lock
index ba741cd..66212a0 100644
--- a/Cargo.lock
+++ b/Cargo.lock
@@ -61,6 +61,12 @@ dependencies = [
  "windows-sys",
 ]
 
+[[package]]
+name = "anyhow"
+version = "1.0.100"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "a23eb6b1614318a8071c9b2521f36b424b2c83db5eb3a0fead4a6c0809af6e61"
+
 [[package]]
 name = "autocfg"
 version = "1.5.0"
@@ -129,6 +135,7 @@ checksum = 
"b05b61dc5112cbb17e4b6cd61790d9845d13888356391624cbe7e41efeac1e75"
 name = "datasketches"
 version = "0.1.0"
 dependencies = [
+ "anyhow",
  "byteorder",
  "googletest",
 ]
diff --git a/Cargo.toml b/Cargo.toml
index e991913..4c93ee2 100644
--- a/Cargo.toml
+++ b/Cargo.toml
@@ -32,6 +32,7 @@ rust-version = "1.85.0"
 datasketches = { path = "datasketches" }
 
 # Crates.io dependencies
+anyhow = { version = "1.0.100" }
 byteorder = { version = "1.5.0" }
 clap = { version = "4.5.20", features = ["derive"] }
 googletest = { version = "0.14.2" }
diff --git a/datasketches/Cargo.toml b/datasketches/Cargo.toml
index bddcb7e..8a150ff 100644
--- a/datasketches/Cargo.toml
+++ b/datasketches/Cargo.toml
@@ -35,6 +35,7 @@ all-features = true
 rustdoc-args = ["--cfg", "docsrs"]
 
 [dependencies]
+anyhow = { workspace = true }
 byteorder = { workspace = true }
 
 [dev-dependencies]
diff --git a/datasketches/src/countmin/sketch.rs 
b/datasketches/src/countmin/sketch.rs
index 3fabfb4..e98e96e 100644
--- a/datasketches/src/countmin/sketch.rs
+++ b/datasketches/src/countmin/sketch.rs
@@ -29,7 +29,7 @@ use crate::countmin::serialization::LONG_SIZE_BYTES;
 use crate::countmin::serialization::PREAMBLE_LONGS_SHORT;
 use crate::countmin::serialization::SERIAL_VERSION;
 use crate::countmin::serialization::compute_seed_hash;
-use crate::error::SerdeError;
+use crate::error::Error;
 use crate::hash::MurmurHash3X64128;
 
 const MAX_TABLE_ENTRIES: usize = 1 << 30;
@@ -231,14 +231,14 @@ impl CountMinSketch {
     }
 
     /// Deserializes a sketch from bytes using the default seed.
-    pub fn deserialize(bytes: &[u8]) -> Result<Self, SerdeError> {
+    pub fn deserialize(bytes: &[u8]) -> Result<Self, Error> {
         Self::deserialize_with_seed(bytes, DEFAULT_SEED)
     }
 
     /// Deserializes a sketch from bytes using the provided seed.
-    pub fn deserialize_with_seed(bytes: &[u8], seed: u64) -> Result<Self, 
SerdeError> {
-        fn make_error(tag: &'static str) -> impl FnOnce(std::io::Error) -> 
SerdeError {
-            move |_| SerdeError::InsufficientData(tag.to_string())
+    pub fn deserialize_with_seed(bytes: &[u8], seed: u64) -> Result<Self, 
Error> {
+        fn make_error(tag: &'static str) -> impl FnOnce(std::io::Error) -> 
Error {
+            move |_| Error::insufficient_data(tag)
         }
 
         let mut cursor = Cursor::new(bytes);
@@ -249,21 +249,23 @@ impl CountMinSketch {
         cursor.read_u32::<LE>().map_err(make_error("unused32"))?;
 
         if family_id != COUNTMIN_FAMILY_ID {
-            return Err(SerdeError::InvalidFamily(format!(
-                "expected {} (CountMinSketch), got {}",
-                COUNTMIN_FAMILY_ID, family_id
-            )));
+            return Err(Error::invalid_family(
+                COUNTMIN_FAMILY_ID,
+                family_id,
+                "CountMinSketch",
+            ));
         }
         if serial_version != SERIAL_VERSION {
-            return Err(SerdeError::UnsupportedVersion(format!(
-                "expected {}, got {}",
-                SERIAL_VERSION, serial_version
-            )));
+            return Err(Error::unsupported_serial_version(
+                SERIAL_VERSION,
+                serial_version,
+            ));
         }
         if preamble_longs != PREAMBLE_LONGS_SHORT {
-            return Err(SerdeError::MalformedData(format!(
-                "unsupported preamble_longs {preamble_longs}"
-            )));
+            return Err(Error::invalid_preamble_longs(
+                PREAMBLE_LONGS_SHORT,
+                preamble_longs,
+            ));
         }
 
         let num_buckets = 
cursor.read_u32::<LE>().map_err(make_error("num_buckets"))?;
@@ -273,9 +275,8 @@ impl CountMinSketch {
 
         let expected_seed_hash = compute_seed_hash(seed);
         if seed_hash != expected_seed_hash {
-            return Err(SerdeError::InvalidParameter(format!(
-                "incompatible seed hash: expected {}, got {}",
-                expected_seed_hash, seed_hash
+            return Err(Error::deserial(format!(
+                "incompatible seed hash: expected {expected_seed_hash}, got 
{seed_hash}",
             )));
         }
 
@@ -329,26 +330,19 @@ fn entries_for_config(num_hashes: u8, num_buckets: u32) 
-> usize {
     entries
 }
 
-fn entries_for_config_checked(num_hashes: u8, num_buckets: u32) -> 
Result<usize, SerdeError> {
+fn entries_for_config_checked(num_hashes: u8, num_buckets: u32) -> 
Result<usize, Error> {
     if num_hashes == 0 {
-        return Err(SerdeError::InvalidParameter(
-            "num_hashes must be at least 1".to_string(),
-        ));
+        return Err(Error::deserial("num_hashes must be at least 1"));
     }
     if num_buckets < 3 {
-        return Err(SerdeError::InvalidParameter(
-            "num_buckets must be at least 3".to_string(),
-        ));
+        return Err(Error::deserial("num_buckets must be at least 3"));
     }
     let entries = (num_hashes as usize)
         .checked_mul(num_buckets as usize)
-        .ok_or_else(|| {
-            SerdeError::InvalidParameter("num_hashes * num_buckets overflows 
usize".to_string())
-        })?;
+        .ok_or_else(|| Error::deserial("num_hashes * num_buckets overflows 
usize"))?;
     if entries >= MAX_TABLE_ENTRIES {
-        return Err(SerdeError::InvalidParameter(format!(
-            "num_hashes * num_buckets must be < {}",
-            MAX_TABLE_ENTRIES
+        return Err(Error::deserial(format!(
+            "num_hashes * num_buckets must be < {MAX_TABLE_ENTRIES}",
         )));
     }
     Ok(entries)
diff --git a/datasketches/src/error.rs b/datasketches/src/error.rs
index 88e71d9..624ee0a 100644
--- a/datasketches/src/error.rs
+++ b/datasketches/src/error.rs
@@ -19,31 +19,152 @@
 
 use std::fmt;
 
-/// Errors that can occur during sketch serialization or deserialization
-#[derive(Debug, Clone)]
-pub enum SerdeError {
-    /// Insufficient data in buffer
-    InsufficientData(String),
-    /// Invalid sketch family identifier
-    InvalidFamily(String),
-    /// Unsupported serialization version
-    UnsupportedVersion(String),
-    /// Invalid parameter value
-    InvalidParameter(String),
-    /// Malformed or corrupt sketch data
-    MalformedData(String),
-}
-
-impl fmt::Display for SerdeError {
-    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
+/// ErrorKind is all kinds of Error of datasketches.
+#[derive(Clone, Copy, Debug, PartialEq, Eq, Hash)]
+#[non_exhaustive]
+pub enum ErrorKind {
+    /// The argument provided is invalid.
+    InvalidArgument,
+    /// The sketch data deserializing is malformed.
+    InvalidData,
+}
+
+impl ErrorKind {
+    /// Convert this error kind instance into static str.
+    pub const fn into_static(self) -> &'static str {
         match self {
-            SerdeError::InsufficientData(msg) => write!(f, "insufficient data: 
{}", msg),
-            SerdeError::InvalidFamily(msg) => write!(f, "invalid family: {}", 
msg),
-            SerdeError::UnsupportedVersion(msg) => write!(f, "unsupported 
version: {}", msg),
-            SerdeError::InvalidParameter(msg) => write!(f, "invalid parameter: 
{}", msg),
-            SerdeError::MalformedData(msg) => write!(f, "malformed data: {}", 
msg),
+            ErrorKind::InvalidArgument => "InvalidArgument",
+            ErrorKind::InvalidData => "InvalidData",
+        }
+    }
+}
+
+impl fmt::Display for ErrorKind {
+    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
+        write!(f, "{}", self.into_static())
+    }
+}
+
+/// Error is the error struct returned by all datasketches functions.
+pub struct Error {
+    kind: ErrorKind,
+    message: String,
+    context: Vec<(&'static str, String)>,
+}
+
+impl Error {
+    /// Create a new Error with error kind and message.
+    pub fn new(kind: ErrorKind, message: impl Into<String>) -> Self {
+        Self {
+            kind,
+            message: message.into(),
+            context: vec![],
         }
     }
+
+    /// Add more context in error.
+    pub fn with_context(mut self, key: &'static str, value: impl ToString) -> 
Self {
+        self.context.push((key, value.to_string()));
+        self
+    }
+
+    /// Return error's kind.
+    pub fn kind(&self) -> ErrorKind {
+        self.kind
+    }
+
+    /// Return error's message.
+    pub fn message(&self) -> &str {
+        self.message.as_str()
+    }
+}
+
+// Convenience constructors for deserialization errors
+impl Error {
+    pub(crate) fn deserial(msg: impl Into<String>) -> Self {
+        Self::new(ErrorKind::InvalidData, msg)
+    }
+
+    pub(crate) fn insufficient_data(msg: impl fmt::Display) -> Self {
+        Self::deserial(format!("insufficient data: {msg}"))
+    }
+
+    pub(crate) fn insufficient_data_of(context: &'static str, msg: impl 
fmt::Display) -> Self {
+        Self::deserial(format!("insufficient data ({context}): {msg}"))
+    }
+
+    pub(crate) fn invalid_family(expected: u8, actual: u8, name: &'static str) 
-> Self {
+        Self::deserial(format!(
+            "invalid family: expected {expected} ({name}), got {actual}"
+        ))
+    }
+
+    pub(crate) fn unsupported_serial_version(expected: u8, actual: u8) -> Self 
{
+        Self::deserial(format!(
+            "unsupported serial version: expected {expected}, got {actual}"
+        ))
+    }
+
+    pub(crate) fn invalid_preamble_longs(expected: u8, actual: u8) -> Self {
+        Self::deserial(format!(
+            "invalid preamble longs: expected {expected}, got {actual}"
+        ))
+    }
+}
+
+impl fmt::Debug for Error {
+    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
+        // If alternate has been specified, we will print like Debug.
+        if f.alternate() {
+            let mut de = f.debug_struct("Error");
+            de.field("kind", &self.kind);
+            de.field("message", &self.message);
+            de.field("context", &self.context);
+            return de.finish();
+        }
+
+        write!(f, "{}", self.kind)?;
+        if !self.message.is_empty() {
+            write!(f, " => {}", self.message)?;
+        }
+        writeln!(f)?;
+
+        if !self.context.is_empty() {
+            writeln!(f)?;
+            writeln!(f, "Context:")?;
+            for (k, v) in self.context.iter() {
+                writeln!(f, "   {k}: {v}")?;
+            }
+        }
+
+        Ok(())
+    }
+}
+
+impl fmt::Display for Error {
+    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
+        write!(f, "{}", self.kind)?;
+
+        if !self.context.is_empty() {
+            write!(f, ", context: {{ ")?;
+            write!(
+                f,
+                "{}",
+                self.context
+                    .iter()
+                    .map(|(k, v)| format!("{k}: {v}"))
+                    .collect::<Vec<_>>()
+                    .join(", ")
+            )?;
+            write!(f, " }}")?;
+        }
+
+        if !self.message.is_empty() {
+            write!(f, " => {}", self.message)?;
+        }
+
+        Ok(())
+    }
 }
 
-impl std::error::Error for SerdeError {}
+impl std::error::Error for Error {}
diff --git a/datasketches/src/hll/array4.rs b/datasketches/src/hll/array4.rs
index 44707b6..fbef5e4 100644
--- a/datasketches/src/hll/array4.rs
+++ b/datasketches/src/hll/array4.rs
@@ -21,7 +21,7 @@
 //! When values exceed 4 bits after cur_min offset, they're stored in an 
auxiliary hash map.
 
 use super::aux_map::AuxMap;
-use crate::error::SerdeError;
+use crate::error::Error;
 use crate::hll::NumStdDev;
 use crate::hll::estimator::HipEstimator;
 use crate::hll::get_slot;
@@ -289,13 +289,13 @@ impl Array4 {
         lg_config_k: u8,
         compact: bool,
         ooo: bool,
-    ) -> Result<Self, SerdeError> {
+    ) -> Result<Self, Error> {
         use crate::hll::get_slot;
         use crate::hll::get_value;
         use crate::hll::serialization::*;
 
         if bytes.len() < HLL_PREAMBLE_SIZE {
-            return Err(SerdeError::InsufficientData(format!(
+            return Err(Error::insufficient_data(format!(
                 "expected at least {}, got {}",
                 HLL_PREAMBLE_SIZE,
                 bytes.len()
@@ -324,7 +324,7 @@ impl Array4 {
         };
 
         if bytes.len() < expected_len {
-            return Err(SerdeError::InsufficientData(format!(
+            return Err(Error::insufficient_data(format!(
                 "expected {}, got {}",
                 expected_len,
                 bytes.len()
@@ -392,7 +392,7 @@ impl Array4 {
 
         // Write standard header
         bytes[PREAMBLE_INTS_BYTE] = HLL_PREINTS;
-        bytes[SER_VER_BYTE] = SER_VER;
+        bytes[SER_VER_BYTE] = SERIAL_VER;
         bytes[FAMILY_BYTE] = HLL_FAMILY_ID;
         bytes[LG_K_BYTE] = lg_config_k;
         bytes[LG_ARR_BYTE] = 0; // Not used for HLL mode
diff --git a/datasketches/src/hll/array6.rs b/datasketches/src/hll/array6.rs
index 8c7138b..5d36cb5 100644
--- a/datasketches/src/hll/array6.rs
+++ b/datasketches/src/hll/array6.rs
@@ -21,7 +21,7 @@
 //! This is sufficient for most HLL use cases without needing exception 
handling or
 //! cur_min optimization like Array4.
 
-use crate::error::SerdeError;
+use crate::error::Error;
 use crate::hll::NumStdDev;
 use crate::hll::estimator::HipEstimator;
 use crate::hll::get_slot;
@@ -173,7 +173,7 @@ impl Array6 {
         lg_config_k: u8,
         compact: bool,
         ooo: bool,
-    ) -> Result<Self, SerdeError> {
+    ) -> Result<Self, Error> {
         use crate::hll::serialization::*;
 
         let k = 1 << lg_config_k;
@@ -185,7 +185,7 @@ impl Array6 {
         };
 
         if bytes.len() < expected_len {
-            return Err(SerdeError::InsufficientData(format!(
+            return Err(Error::insufficient_data(format!(
                 "expected {}, got {}",
                 expected_len,
                 bytes.len()
@@ -234,7 +234,7 @@ impl Array6 {
 
         // Write standard header
         bytes[PREAMBLE_INTS_BYTE] = HLL_PREINTS;
-        bytes[SER_VER_BYTE] = SER_VER;
+        bytes[SER_VER_BYTE] = SERIAL_VER;
         bytes[FAMILY_BYTE] = HLL_FAMILY_ID;
         bytes[LG_K_BYTE] = lg_config_k;
         bytes[LG_ARR_BYTE] = 0; // Not used for HLL mode
diff --git a/datasketches/src/hll/array8.rs b/datasketches/src/hll/array8.rs
index 33e2cca..dd3b556 100644
--- a/datasketches/src/hll/array8.rs
+++ b/datasketches/src/hll/array8.rs
@@ -20,7 +20,7 @@
 //! Array8 is the simplest HLL array implementation, storing one byte per slot.
 //! This provides the maximum value range (0-255) with no bit-packing 
complexity.
 
-use crate::error::SerdeError;
+use crate::error::Error;
 use crate::hll::NumStdDev;
 use crate::hll::estimator::HipEstimator;
 use crate::hll::get_slot;
@@ -247,7 +247,7 @@ impl Array8 {
         lg_config_k: u8,
         compact: bool,
         ooo: bool,
-    ) -> Result<Self, SerdeError> {
+    ) -> Result<Self, Error> {
         use crate::hll::serialization::*;
 
         let k = 1 << lg_config_k;
@@ -258,7 +258,7 @@ impl Array8 {
         };
 
         if bytes.len() < expected_len {
-            return Err(SerdeError::InsufficientData(format!(
+            return Err(Error::insufficient_data(format!(
                 "expected {}, got {}",
                 expected_len,
                 bytes.len()
@@ -306,7 +306,7 @@ impl Array8 {
 
         // Write standard header
         bytes[PREAMBLE_INTS_BYTE] = HLL_PREINTS;
-        bytes[SER_VER_BYTE] = SER_VER;
+        bytes[SER_VER_BYTE] = SERIAL_VER;
         bytes[FAMILY_BYTE] = HLL_FAMILY_ID;
         bytes[LG_K_BYTE] = lg_config_k;
         bytes[LG_ARR_BYTE] = 0; // Not used for HLL mode
diff --git a/datasketches/src/hll/hash_set.rs b/datasketches/src/hll/hash_set.rs
index 05f5ad2..f6bff05 100644
--- a/datasketches/src/hll/hash_set.rs
+++ b/datasketches/src/hll/hash_set.rs
@@ -20,7 +20,7 @@
 //! Uses open addressing with a custom stride function to handle collisions.
 //! Provides better performance than List when many coupons are stored.
 
-use crate::error::SerdeError;
+use crate::error::Error;
 use crate::hll::HllType;
 use crate::hll::KEY_MASK_26;
 use crate::hll::container::COUPON_EMPTY;
@@ -84,7 +84,7 @@ impl HashSet {
     }
 
     /// Deserialize a HashSet from bytes
-    pub fn deserialize(bytes: &[u8], compact: bool) -> Result<Self, 
SerdeError> {
+    pub fn deserialize(bytes: &[u8], compact: bool) -> Result<Self, Error> {
         // Read coupon count from bytes 8-11
         let coupon_count = read_u32_le(bytes, HASH_SET_COUNT_INT) as usize;
 
@@ -95,7 +95,7 @@ impl HashSet {
             // Compact mode: only couponCount coupons are stored
             let expected_len = HASH_SET_INT_ARR_START + (coupon_count * 4);
             if bytes.len() < expected_len {
-                return Err(SerdeError::InsufficientData(format!(
+                return Err(Error::insufficient_data(format!(
                     "expected {}, got {}",
                     expected_len,
                     bytes.len()
@@ -115,7 +115,7 @@ impl HashSet {
             let array_size = 1 << lg_arr;
             let expected_len = HASH_SET_INT_ARR_START + (array_size * 4);
             if bytes.len() < expected_len {
-                return Err(SerdeError::InsufficientData(format!(
+                return Err(Error::insufficient_data(format!(
                     "expected {}, got {}",
                     expected_len,
                     bytes.len()
@@ -153,7 +153,7 @@ impl HashSet {
 
         // Write preamble
         bytes[PREAMBLE_INTS_BYTE] = HASH_SET_PREINTS;
-        bytes[SER_VER_BYTE] = SER_VER;
+        bytes[SER_VER_BYTE] = SERIAL_VER;
         bytes[FAMILY_BYTE] = HLL_FAMILY_ID;
         bytes[LG_K_BYTE] = lg_config_k;
         bytes[LG_ARR_BYTE] = lg_arr as u8;
diff --git a/datasketches/src/hll/list.rs b/datasketches/src/hll/list.rs
index d01a9b7..c705383 100644
--- a/datasketches/src/hll/list.rs
+++ b/datasketches/src/hll/list.rs
@@ -20,7 +20,7 @@
 //! Provides sequential storage with linear search for duplicates.
 //! Efficient for small numbers of coupons before transitioning to HashSet.
 
-use crate::error::SerdeError;
+use crate::error::Error;
 use crate::hll::HllType;
 use crate::hll::container::COUPON_EMPTY;
 use crate::hll::container::Container;
@@ -66,7 +66,7 @@ impl List {
     }
 
     /// Deserialize a List from bytes
-    pub fn deserialize(bytes: &[u8], empty: bool, compact: bool) -> 
Result<Self, SerdeError> {
+    pub fn deserialize(bytes: &[u8], empty: bool, compact: bool) -> 
Result<Self, Error> {
         // Read coupon count from byte 6
         let coupon_count = bytes[LIST_COUNT_BYTE] as usize;
 
@@ -77,7 +77,7 @@ impl List {
         // Validate length
         let expected_len = LIST_INT_ARR_START + (array_size * 4);
         if bytes.len() < expected_len {
-            return Err(SerdeError::InsufficientData(format!(
+            return Err(Error::insufficient_data(format!(
                 "expected {}, got {}",
                 expected_len,
                 bytes.len()
@@ -113,7 +113,7 @@ impl List {
 
         // Write preamble
         bytes[PREAMBLE_INTS_BYTE] = LIST_PREINTS;
-        bytes[SER_VER_BYTE] = SER_VER;
+        bytes[SER_VER_BYTE] = SERIAL_VER;
         bytes[FAMILY_BYTE] = HLL_FAMILY_ID;
         bytes[LG_K_BYTE] = lg_config_k;
         bytes[LG_ARR_BYTE] = lg_arr as u8;
diff --git a/datasketches/src/hll/serialization.rs 
b/datasketches/src/hll/serialization.rs
index e111393..b99a262 100644
--- a/datasketches/src/hll/serialization.rs
+++ b/datasketches/src/hll/serialization.rs
@@ -24,7 +24,7 @@
 pub const HLL_FAMILY_ID: u8 = 7;
 
 /// Current serialization version
-pub const SER_VER: u8 = 1;
+pub const SERIAL_VER: u8 = 1;
 
 /// Flag indicating sketch is empty (no values inserted)
 pub const EMPTY_FLAG_MASK: u8 = 4;
diff --git a/datasketches/src/hll/sketch.rs b/datasketches/src/hll/sketch.rs
index 6887338..a5f0aac 100644
--- a/datasketches/src/hll/sketch.rs
+++ b/datasketches/src/hll/sketch.rs
@@ -22,7 +22,7 @@
 
 use std::hash::Hash;
 
-use crate::error::SerdeError;
+use crate::error::Error;
 use crate::hll::HllType;
 use crate::hll::NumStdDev;
 use crate::hll::RESIZE_DENOMINATOR;
@@ -212,16 +212,16 @@ impl HllSketch {
     }
 
     /// Deserializes an HLL sketch from bytes
-    pub fn deserialize(bytes: &[u8]) -> Result<HllSketch, SerdeError> {
+    pub fn deserialize(bytes: &[u8]) -> Result<HllSketch, Error> {
         if bytes.len() < 8 {
-            return Err(SerdeError::InsufficientData(
-                "sketch data too short (< 8 bytes)".to_string(),
+            return Err(Error::insufficient_data(
+                "sketch data too short (< 8 bytes)",
             ));
         }
 
         // Read and validate preamble
         let preamble_ints = bytes[PREAMBLE_INTS_BYTE];
-        let ser_ver = bytes[SER_VER_BYTE];
+        let serial_ver = bytes[SER_VER_BYTE];
         let family_id = bytes[FAMILY_BYTE];
         let lg_config_k = bytes[LG_K_BYTE];
         let flags = bytes[FLAGS_BYTE];
@@ -229,25 +229,18 @@ impl HllSketch {
 
         // Verify family ID
         if family_id != HLL_FAMILY_ID {
-            return Err(SerdeError::InvalidFamily(format!(
-                "expected {} (HLL), got {}",
-                HLL_FAMILY_ID, family_id
-            )));
+            return Err(Error::invalid_family(HLL_FAMILY_ID, family_id, "HLL"));
         }
 
         // Verify serialization version
-        if ser_ver != SER_VER {
-            return Err(SerdeError::UnsupportedVersion(format!(
-                "expected {}, got {}",
-                SER_VER, ser_ver
-            )));
+        if serial_ver != SERIAL_VER {
+            return Err(Error::unsupported_serial_version(SERIAL_VER, 
serial_ver));
         }
 
         // Verify lg_k range (4-21 are valid)
         if !(4..=21).contains(&lg_config_k) {
-            return Err(SerdeError::InvalidParameter(format!(
-                "lg_k must be in [4; 21], got {}",
-                lg_config_k
+            return Err(Error::deserial(format!(
+                "lg_k must be in [4; 21], got {lg_config_k}",
             )));
         }
 
@@ -256,10 +249,7 @@ impl HllSketch {
             TGT_HLL6 => HllType::Hll6,
             TGT_HLL8 => HllType::Hll8,
             hll_type => {
-                return Err(SerdeError::MalformedData(format!(
-                    "invalid HLL type: {}",
-                    hll_type
-                )));
+                return Err(Error::deserial(format!("invalid HLL type: 
{hll_type}")));
             }
         };
 
@@ -272,9 +262,9 @@ impl HllSketch {
             match extract_cur_mode(mode_byte) {
                 CUR_MODE_LIST => {
                     if preamble_ints != LIST_PREINTS {
-                        return Err(SerdeError::MalformedData(format!(
+                        return Err(Error::deserial(format!(
                             "LIST mode preamble: expected {}, got {}",
-                            LIST_PREINTS, preamble_ints
+                            LIST_PREINTS, preamble_ints,
                         )));
                     }
 
@@ -283,7 +273,7 @@ impl HllSketch {
                 }
                 CUR_MODE_SET => {
                     if preamble_ints != HASH_SET_PREINTS {
-                        return Err(SerdeError::MalformedData(format!(
+                        return Err(Error::deserial(format!(
                             "SET mode preamble: expected {}, got {}",
                             HASH_SET_PREINTS, preamble_ints
                         )));
@@ -294,7 +284,7 @@ impl HllSketch {
                 }
                 CUR_MODE_HLL => {
                     if preamble_ints != HLL_PREINTS {
-                        return Err(SerdeError::MalformedData(format!(
+                        return Err(Error::deserial(format!(
                             "HLL mode preamble: expected {}, got {}",
                             HLL_PREINTS, preamble_ints
                         )));
@@ -309,7 +299,7 @@ impl HllSketch {
                             .map(Mode::Array8)?,
                     }
                 }
-                mode => return Err(SerdeError::MalformedData(format!("invalid 
mode: {}", mode))),
+                mode => return Err(Error::deserial(format!("invalid mode: 
{mode}"))),
             };
 
         Ok(HllSketch { lg_config_k, mode })
diff --git a/datasketches/src/tdigest/sketch.rs 
b/datasketches/src/tdigest/sketch.rs
index 13aa6ca..a0f3883 100644
--- a/datasketches/src/tdigest/sketch.rs
+++ b/datasketches/src/tdigest/sketch.rs
@@ -24,7 +24,8 @@ use byteorder::BE;
 use byteorder::LE;
 use byteorder::ReadBytesExt;
 
-use crate::error::SerdeError;
+use crate::error::Error;
+use crate::error::ErrorKind;
 use crate::tdigest::serialization::*;
 
 /// The default value of K if one is not specified.
@@ -60,9 +61,11 @@ impl Default for TDigestMut {
 impl TDigestMut {
     /// Creates a tdigest instance with the given value of k.
     ///
+    /// The fallible version of this method is [`TDigestMut::try_new`].
+    ///
     /// # Panics
     ///
-    /// If k is less than 10
+    /// Panics if k is less than 10
     pub fn new(k: u16) -> Self {
         Self::make(
             k,
@@ -75,6 +78,32 @@ impl TDigestMut {
         )
     }
 
+    /// Creates a tdigest instance with the given value of k.
+    ///
+    /// The panicking version of this method is [`TDigestMut::new`].
+    ///
+    /// # Errors
+    ///
+    /// If k is less than 10, returns [`ErrorKind::InvalidArgument`].
+    pub fn try_new(k: u16) -> Result<Self, Error> {
+        if k < 10 {
+            return Err(Error::new(
+                ErrorKind::InvalidArgument,
+                format!("k must be at least 10, got {k}"),
+            ));
+        }
+
+        Ok(Self::make(
+            k,
+            false,
+            f64::INFINITY,
+            f64::NEG_INFINITY,
+            vec![],
+            0,
+            vec![],
+        ))
+    }
+
     // for deserialization
     fn make(
         k: u16,
@@ -205,27 +234,7 @@ impl TDigestMut {
         }
     }
 
-    /// Returns an approximation to the Cumulative Distribution Function 
(CDF), which is the
-    /// cumulative analog of the PMF, of the input stream given a set of split 
points.
-    ///
-    /// # Arguments
-    ///
-    /// * `split_points`: An array of _m_ unique, monotonically increasing 
values that divide the
-    ///   input domain into _m+1_ consecutive disjoint intervals.
-    ///
-    /// # Returns
-    ///
-    /// An array of m+1 doubles, which are a consecutive approximation to the 
CDF of the input
-    /// stream given the split points. The value at array position j of the 
returned CDF array
-    /// is the sum of the returned values in positions 0 through j of the 
returned PMF array.
-    /// This can be viewed as array of ranks of the given split points plus 
one more value that
-    /// is always 1.
-    ///
-    /// Returns `None` if TDigest is empty.
-    ///
-    /// # Panics
-    ///
-    /// If `split_points` is not unique, not monotonically increasing, or 
contains `NaN` values.
+    /// See [`TDigest::cdf`].
     pub fn cdf(&mut self, split_points: &[f64]) -> Option<Vec<f64>> {
         check_split_points(split_points);
 
@@ -236,24 +245,7 @@ impl TDigestMut {
         self.view().cdf(split_points)
     }
 
-    /// Returns an approximation to the Probability Mass Function (PMF) of the 
input stream
-    /// given a set of split points.
-    ///
-    /// # Arguments
-    ///
-    /// * `split_points`: An array of _m_ unique, monotonically increasing 
values that divide the
-    ///   input domain into _m+1_ consecutive disjoint intervals (bins).
-    ///
-    /// # Returns
-    ///
-    /// An array of m+1 doubles each of which is an approximation to the 
fraction of the input
-    /// stream values (the mass) that fall into one of those intervals.
-    ///
-    /// Returns `None` if TDigest is empty.
-    ///
-    /// # Panics
-    ///
-    /// If `split_points` is not unique, not monotonically increasing, or 
contains `NaN` values.
+    /// See [`TDigest::pmf`].
     pub fn pmf(&mut self, split_points: &[f64]) -> Option<Vec<f64>> {
         check_split_points(split_points);
 
@@ -264,13 +256,7 @@ impl TDigestMut {
         self.view().pmf(split_points)
     }
 
-    /// Compute approximate normalized rank (from 0 to 1 inclusive) of the 
given value.
-    ///
-    /// Returns `None` if TDigest is empty.
-    ///
-    /// # Panics
-    ///
-    /// If the value is `NaN`.
+    /// See [`TDigest::rank`].
     pub fn rank(&mut self, value: f64) -> Option<f64> {
         assert!(!value.is_nan(), "value must not be NaN");
 
@@ -291,13 +277,7 @@ impl TDigestMut {
         self.view().rank(value)
     }
 
-    /// Compute approximate quantile value corresponding to the given 
normalized rank.
-    ///
-    /// Returns `None` if TDigest is empty.
-    ///
-    /// # Panics
-    ///
-    /// If rank is not in [0.0, 1.0].
+    /// See [`TDigest::quantile`].
     pub fn quantile(&mut self, rank: f64) -> Option<f64> {
         assert!((0.0..=1.0).contains(&rank), "rank must be in [0.0, 1.0]");
 
@@ -390,9 +370,9 @@ impl TDigestMut {
     ///
     /// [^1]: This is to support reading the `tdigest<float>` format from the 
C++ implementation.
     /// [^2]: <https://github.com/tdunning/t-digest>
-    pub fn deserialize(bytes: &[u8], is_f32: bool) -> Result<Self, SerdeError> 
{
-        fn make_error(tag: &'static str) -> impl FnOnce(std::io::Error) -> 
SerdeError {
-            move |_| SerdeError::InsufficientData(tag.to_string())
+    pub fn deserialize(bytes: &[u8], is_f32: bool) -> Result<Self, Error> {
+        fn make_error(tag: &'static str) -> impl FnOnce(std::io::Error) -> 
Error {
+            move |_| Error::insufficient_data(tag)
         }
 
         let mut cursor = Cursor::new(bytes);
@@ -401,25 +381,25 @@ impl TDigestMut {
         let serial_version = cursor.read_u8().map_err(make_error("serial_version"))?;
         let family_id = cursor.read_u8().map_err(make_error("family_id"))?;
         if family_id != TDIGEST_FAMILY_ID {
-            if preamble_longs == 0 && serial_version == 0 && family_id == 0 {
-                return Self::deserialize_compat(bytes);
-            }
-            return Err(SerdeError::InvalidFamily(format!(
-                "expected {} (TDigest), got {}",
-                TDIGEST_FAMILY_ID, family_id
-            )));
+            return if preamble_longs == 0 && serial_version == 0 && family_id == 0 {
+                Self::deserialize_compat(bytes)
+            } else {
+                Err(Error::invalid_family(
+                    TDIGEST_FAMILY_ID,
+                    family_id,
+                    "TDigest",
+                ))
+            };
         }
         if serial_version != SERIAL_VERSION {
-            return Err(SerdeError::UnsupportedVersion(format!(
-                "expected {}, got {}",
-                SERIAL_VERSION, serial_version
-            )));
+            return Err(Error::unsupported_serial_version(
+                SERIAL_VERSION,
+                serial_version,
+            ));
         }
         let k = cursor.read_u16::<LE>().map_err(make_error("k"))?;
         if k < 10 {
-            return Err(SerdeError::InvalidParameter(format!(
-                "k must be at least 10, got {k}"
-            )));
+            return Err(Error::deserial(format!("k must be at least 10, got {k}")));
         }
         let flags = cursor.read_u8().map_err(make_error("flags"))?;
         let is_empty = (flags & FLAGS_IS_EMPTY) != 0;
@@ -430,10 +410,10 @@ impl TDigestMut {
             PREAMBLE_LONGS_MULTIPLE
         };
         if preamble_longs != expected_preamble_longs {
-            return Err(SerdeError::MalformedData(format!(
-                "expected preamble_longs to be {}, got {}",
-                expected_preamble_longs, preamble_longs
-            )));
+            return Err(Error::invalid_preamble_longs(
+                expected_preamble_longs,
+                preamble_longs,
+            ));
         }
         cursor.read_u16::<LE>().map_err(make_error("<unused>"))?; // unused
         if is_empty {
@@ -452,7 +432,7 @@ impl TDigestMut {
                     .map_err(make_error("single_value"))?
             };
             check_non_nan(value, "single_value")?;
-            check_non_infinite(value, "single_value")?;
+            check_finite(value, "single_value")?;
             return Ok(TDigestMut::make(
                 k,
                 reverse_merge,
@@ -500,7 +480,7 @@ impl TDigestMut {
                 )
             };
             check_non_nan(mean, "centroid mean")?;
-            check_non_infinite(mean, "centroid")?;
+            check_finite(mean, "centroid")?;
             let weight = check_nonzero(weight, "centroid weight")?;
             centroids_weight += weight.get();
             centroids.push(Centroid { mean, weight });
@@ -517,7 +497,7 @@ impl TDigestMut {
                     .map_err(make_error("buffered_value"))?
             };
             check_non_nan(value, "buffered_value mean")?;
-            check_non_infinite(value, "buffered_value mean")?;
+            check_finite(value, "buffered_value mean")?;
             buffer.push(value);
         }
         Ok(TDigestMut::make(
@@ -533,9 +513,9 @@ impl TDigestMut {
 
     // compatibility with the format of the reference implementation
     // default byte order of ByteBuffer is used there, which is big endian
-    fn deserialize_compat(bytes: &[u8]) -> Result<Self, SerdeError> {
-        fn make_error(tag: &'static str) -> impl FnOnce(std::io::Error) -> SerdeError {
-            move |_| SerdeError::InsufficientData(format!("{tag} in compat format"))
+    fn deserialize_compat(bytes: &[u8]) -> Result<Self, Error> {
+        fn make_error(tag: &'static str) -> impl FnOnce(std::io::Error) -> Error {
+            move |_| Error::insufficient_data_of("compat format", tag)
         }
 
         let mut cursor = Cursor::new(bytes);
@@ -543,8 +523,8 @@ impl TDigestMut {
         let ty = cursor.read_u32::<BE>().map_err(make_error("type"))?;
         match ty {
             COMPAT_DOUBLE => {
-                fn make_error(tag: &'static str) -> impl FnOnce(std::io::Error) -> SerdeError {
-                    move |_| SerdeError::InsufficientData(format!("{tag} in compat double format"))
+                fn make_error(tag: &'static str) -> impl FnOnce(std::io::Error) -> Error {
+                    move |_| Error::insufficient_data_of("compat double format", tag)
                 }
                 // compatibility with asBytes()
                 let min = cursor.read_f64::<BE>().map_err(make_error("min"))?;
@@ -553,8 +533,8 @@ impl TDigestMut {
                 check_non_nan(max, "max in compat double format")?;
                 let k = cursor.read_f64::<BE>().map_err(make_error("k"))? as u16;
                 if k < 10 {
-                    return Err(SerdeError::InvalidParameter(format!(
-                        "k must be at least 10, got {k} in compat double format"
+                    return Err(Error::deserial(format!(
+                        "k must be at least 10 in compat double format, got {k}"
                     )));
                 }
                 let num_centroids = cursor
@@ -568,7 +548,7 @@ impl TDigestMut {
                     let mean = cursor.read_f64::<BE>().map_err(make_error("mean"))?;
                     let weight = check_nonzero(weight, "centroid weight in compat double format")?;
                     check_non_nan(mean, "centroid mean in compat double format")?;
-                    check_non_infinite(mean, "centroid mean in compat double format")?;
+                    check_finite(mean, "centroid mean in compat double format")?;
                     total_weight += weight.get();
                     centroids.push(Centroid { mean, weight });
                 }
@@ -583,8 +563,8 @@ impl TDigestMut {
                 ))
             }
             COMPAT_FLOAT => {
-                fn make_error(tag: &'static str) -> impl FnOnce(std::io::Error) -> SerdeError {
-                    move |_| SerdeError::InsufficientData(format!("{tag} in compat float format"))
+                fn make_error(tag: &'static str) -> impl FnOnce(std::io::Error) -> Error {
+                    move |_| Error::insufficient_data_of("compat float format", tag)
                 }
                 // COMPAT_FLOAT: compatibility with asSmallBytes()
                 // reference implementation uses doubles for min and max
@@ -594,8 +574,8 @@ impl TDigestMut {
                 check_non_nan(max, "max in compat float format")?;
                 let k = cursor.read_f32::<BE>().map_err(make_error("k"))? as u16;
                 if k < 10 {
-                    return Err(SerdeError::InvalidParameter(format!(
-                        "k must be at least 10, got {k} in compat float format"
+                    return Err(Error::deserial(format!(
+                        "k must be at least 10 in compat float format, got {k}"
                     )));
                 }
                 // reference implementation stores capacities of the array of centroids and the
@@ -612,7 +592,7 @@ impl TDigestMut {
                     let mean = cursor.read_f32::<BE>().map_err(make_error("mean"))? as f64;
                     let weight = check_nonzero(weight, "centroid weight in compat float format")?;
                     check_non_nan(mean, "centroid mean in compat float format")?;
-                    check_non_infinite(mean, "centroid mean in compat float format")?;
+                    check_finite(mean, "centroid mean in compat float format")?;
                     total_weight += weight.get();
                     centroids.push(Centroid { mean, weight });
                 }
@@ -626,9 +606,7 @@ impl TDigestMut {
                     vec![],
                 ))
             }
-            ty => Err(SerdeError::InvalidParameter(format!(
-                "unknown TDigest compat type {ty}",
-            ))),
+            ty => Err(Error::deserial(format!("unknown TDigest compat type {ty}"))),
         }
     }
 
@@ -786,7 +764,8 @@ impl TDigest {
     ///
     /// # Panics
     ///
-    /// If `split_points` is not unique, not monotonically increasing, or contains `NaN` values.
+    /// Panics if `split_points` is not unique, not monotonically increasing, or contains `NaN`
+    /// values.
     pub fn cdf(&self, split_points: &[f64]) -> Option<Vec<f64>> {
         self.view().cdf(split_points)
     }
@@ -808,7 +787,8 @@ impl TDigest {
     ///
     /// # Panics
     ///
-    /// If `split_points` is not unique, not monotonically increasing, or contains `NaN` values.
+    /// Panics if `split_points` is not unique, not monotonically increasing, or contains `NaN`
+    /// values.
     pub fn pmf(&self, split_points: &[f64]) -> Option<Vec<f64>> {
         self.view().pmf(split_points)
     }
@@ -819,7 +799,7 @@ impl TDigest {
     ///
     /// # Panics
     ///
-    /// If the value is `NaN`.
+    /// Panics if the value is `NaN`.
     pub fn rank(&self, value: f64) -> Option<f64> {
         assert!(!value.is_nan(), "value must not be NaN");
         self.view().rank(value)
@@ -831,7 +811,7 @@ impl TDigest {
     ///
     /// # Panics
     ///
-    /// If rank is not in [0.0, 1.0].
+    /// Panics if rank is not in [0.0, 1.0].
     pub fn quantile(&self, rank: f64) -> Option<f64> {
         assert!((0.0..=1.0).contains(&rank), "rank must be in [0.0, 1.0]");
         self.view().quantile(rank)
@@ -1129,24 +1109,29 @@ impl Centroid {
     }
 }
 
-fn check_non_nan(value: f64, tag: &'static str) -> Result<(), SerdeError> {
+fn check_non_nan(value: f64, tag: &'static str) -> Result<(), Error> {
     if value.is_nan() {
-        return Err(SerdeError::MalformedData(format!("{tag} cannot be NaN")));
+        return Err(Error::deserial(format!(
+            "malformed data: {tag} cannot be NaN"
+        )));
     }
+
     Ok(())
 }
 
-fn check_non_infinite(value: f64, tag: &'static str) -> Result<(), SerdeError> {
+fn check_finite(value: f64, tag: &'static str) -> Result<(), Error> {
     if value.is_infinite() {
-        return Err(SerdeError::MalformedData(format!(
-            "{tag} cannot be is_infinite"
+        return Err(Error::deserial(format!(
+            "malformed data: {tag} cannot be infinite"
         )));
     }
+
     Ok(())
 }
 
-fn check_nonzero(value: u64, tag: &'static str) -> Result<NonZeroU64, SerdeError> {
-    NonZeroU64::new(value).ok_or_else(|| SerdeError::MalformedData(format!("{tag} cannot be zero")))
+fn check_nonzero(value: u64, tag: &'static str) -> Result<NonZeroU64, Error> {
+    NonZeroU64::new(value)
+        .ok_or_else(|| Error::deserial(format!("malformed data: {tag} cannot be zero")))
 }
 
 /// Generates cluster sizes proportional to `q*(1-q)`.


---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]


Reply via email to