This is an automated email from the ASF dual-hosted git repository.

alamb pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/arrow-rs.git


The following commit(s) were added to refs/heads/main by this push:
     new 959577deab [Variant] impl [Try]From for VariantDecimalXX types (#7809)
959577deab is described below

commit 959577deabf6e27524cc7624e45e58bc2723f478
Author: Ryan Johnson <[email protected]>
AuthorDate: Tue Jul 1 04:06:01 2025 -0700

    [Variant] impl [Try]From for VariantDecimalXX types (#7809)
    
    # Which issue does this PR close?
    
    - Part of https://github.com/apache/arrow-rs/issues/6736
    
    # Rationale for this change
    
    The existing `Variant::as_decimal_xx` methods were incorrect: they
    failed to validate the scale when converting from a wider decimal to a
    narrower one. Fix them, while also improving the ergonomics of the
    decimal code to reduce the chance of similar issues in the future.
    
    # What changes are included in this PR?
    
    Add `From`/`TryFrom` impls for converting to decimal types from other
    decimal types or from their underlying integer type.
    
    Add missing conversions in the `Variant::as_int_xx` and
    `Variant::as_decimal_xx` helpers.
    
    # Are these changes tested?
    
    Partially: existing tests were updated, but more tests are still needed
    for the new conversions (TODO).
    
    # Are there any user-facing changes?
    
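    The `Variant::as_decimal_int32`/`as_decimal_int64`/`as_decimal_int128`
    methods have been renamed to `as_decimal4`/`as_decimal8`/`as_decimal16`
    and now return `VariantDecimal4`/`VariantDecimal8`/`VariantDecimal16`
    values instead of `(integer, scale)` tuples.
    
    New `From`/`TryFrom` conversions are available between the decimal
    types and from their underlying integer types.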
    
    ---------
    
    Co-authored-by: Andrew Lamb <[email protected]>
---
 parquet-variant/src/variant.rs         | 103 ++++++++++---------
 parquet-variant/src/variant/decimal.rs | 178 +++++++++++++++++++--------------
 2 files changed, 157 insertions(+), 124 deletions(-)

diff --git a/parquet-variant/src/variant.rs b/parquet-variant/src/variant.rs
index 3dcb08053a..36564c2bff 100644
--- a/parquet-variant/src/variant.rs
+++ b/parquet-variant/src/variant.rs
@@ -538,6 +538,9 @@ impl<'m, 'v> Variant<'m, 'v> {
             Variant::Int16(i) => i.try_into().ok(),
             Variant::Int32(i) => i.try_into().ok(),
             Variant::Int64(i) => i.try_into().ok(),
+            Variant::Decimal4(d) if d.scale() == 0 => 
d.integer().try_into().ok(),
+            Variant::Decimal8(d) if d.scale() == 0 => 
d.integer().try_into().ok(),
+            Variant::Decimal16(d) if d.scale() == 0 => 
d.integer().try_into().ok(),
             _ => None,
         }
     }
@@ -570,6 +573,9 @@ impl<'m, 'v> Variant<'m, 'v> {
             Variant::Int16(i) => Some(i),
             Variant::Int32(i) => i.try_into().ok(),
             Variant::Int64(i) => i.try_into().ok(),
+            Variant::Decimal4(d) if d.scale() == 0 => 
d.integer().try_into().ok(),
+            Variant::Decimal8(d) if d.scale() == 0 => 
d.integer().try_into().ok(),
+            Variant::Decimal16(d) if d.scale() == 0 => 
d.integer().try_into().ok(),
             _ => None,
         }
     }
@@ -602,6 +608,9 @@ impl<'m, 'v> Variant<'m, 'v> {
             Variant::Int16(i) => Some(i.into()),
             Variant::Int32(i) => Some(i),
             Variant::Int64(i) => i.try_into().ok(),
+            Variant::Decimal4(d) if d.scale() == 0 => Some(d.integer()),
+            Variant::Decimal8(d) if d.scale() == 0 => 
d.integer().try_into().ok(),
+            Variant::Decimal16(d) if d.scale() == 0 => 
d.integer().try_into().ok(),
             _ => None,
         }
     }
@@ -630,6 +639,9 @@ impl<'m, 'v> Variant<'m, 'v> {
             Variant::Int16(i) => Some(i.into()),
             Variant::Int32(i) => Some(i.into()),
             Variant::Int64(i) => Some(i),
+            Variant::Decimal4(d) if d.scale() == 0 => Some(d.integer().into()),
+            Variant::Decimal8(d) if d.scale() == 0 => Some(d.integer()),
+            Variant::Decimal16(d) if d.scale() == 0 => 
d.integer().try_into().ok(),
             _ => None,
         }
     }
@@ -647,37 +659,29 @@ impl<'m, 'v> Variant<'m, 'v> {
     ///
     /// // you can extract decimal parts from smaller or equally-sized decimal 
variants
     /// let v1 = Variant::from(VariantDecimal4::try_new(1234_i32, 2).unwrap());
-    /// assert_eq!(v1.as_decimal_int32(), Some((1234_i32, 2)));
+    /// assert_eq!(v1.as_decimal4(), VariantDecimal4::try_new(1234_i32, 
2).ok());
     ///
     /// // and from larger decimal variants if they fit
     /// let v2 = Variant::from(VariantDecimal8::try_new(1234_i64, 2).unwrap());
-    /// assert_eq!(v2.as_decimal_int32(), Some((1234_i32, 2)));
+    /// assert_eq!(v2.as_decimal4(), VariantDecimal4::try_new(1234_i32, 
2).ok());
     ///
     /// // but not if the value would overflow i32
     /// let v3 = Variant::from(VariantDecimal8::try_new(12345678901i64, 
2).unwrap());
-    /// assert_eq!(v3.as_decimal_int32(), None);
+    /// assert_eq!(v3.as_decimal4(), None);
     ///
     /// // or if the variant is not a decimal
     /// let v4 = Variant::from("hello!");
-    /// assert_eq!(v4.as_decimal_int32(), None);
+    /// assert_eq!(v4.as_decimal4(), None);
     /// ```
-    pub fn as_decimal_int32(&self) -> Option<(i32, u8)> {
+    pub fn as_decimal4(&self) -> Option<VariantDecimal4> {
         match *self {
-            Variant::Decimal4(decimal4) => Some((decimal4.integer(), 
decimal4.scale())),
-            Variant::Decimal8(decimal8) => {
-                if let Ok(converted_integer) = decimal8.integer().try_into() {
-                    Some((converted_integer, decimal8.scale()))
-                } else {
-                    None
-                }
-            }
-            Variant::Decimal16(decimal16) => {
-                if let Ok(converted_integer) = decimal16.integer().try_into() {
-                    Some((converted_integer, decimal16.scale()))
-                } else {
-                    None
-                }
-            }
+            Variant::Int8(i) => i32::from(i).try_into().ok(),
+            Variant::Int16(i) => i32::from(i).try_into().ok(),
+            Variant::Int32(i) => i.try_into().ok(),
+            Variant::Int64(i) => i32::try_from(i).ok()?.try_into().ok(),
+            Variant::Decimal4(decimal4) => Some(decimal4),
+            Variant::Decimal8(decimal8) => decimal8.try_into().ok(),
+            Variant::Decimal16(decimal16) => decimal16.try_into().ok(),
             _ => None,
         }
     }
@@ -691,35 +695,33 @@ impl<'m, 'v> Variant<'m, 'v> {
     /// # Examples
     ///
     /// ```
-    /// use parquet_variant::{Variant, VariantDecimal8, VariantDecimal16};
+    /// use parquet_variant::{Variant, VariantDecimal4, VariantDecimal8, 
VariantDecimal16};
     ///
     /// // you can extract decimal parts from smaller or equally-sized decimal 
variants
-    /// let v1 = Variant::from(VariantDecimal8::try_new(1234_i64, 2).unwrap());
-    /// assert_eq!(v1.as_decimal_int64(), Some((1234_i64, 2)));
+    /// let v1 = Variant::from(VariantDecimal4::try_new(1234_i32, 2).unwrap());
+    /// assert_eq!(v1.as_decimal8(), VariantDecimal8::try_new(1234_i64, 
2).ok());
     ///
     /// // and from larger decimal variants if they fit
     /// let v2 = Variant::from(VariantDecimal16::try_new(1234_i128, 
2).unwrap());
-    /// assert_eq!(v2.as_decimal_int64(), Some((1234_i64, 2)));
+    /// assert_eq!(v2.as_decimal8(), VariantDecimal8::try_new(1234_i64, 
2).ok());
     ///
     /// // but not if the value would overflow i64
     /// let v3 = Variant::from(VariantDecimal16::try_new(2e19 as i128, 
2).unwrap());
-    /// assert_eq!(v3.as_decimal_int64(), None);
+    /// assert_eq!(v3.as_decimal8(), None);
     ///
     /// // or if the variant is not a decimal
     /// let v4 = Variant::from("hello!");
-    /// assert_eq!(v4.as_decimal_int64(), None);
+    /// assert_eq!(v4.as_decimal8(), None);
     /// ```
-    pub fn as_decimal_int64(&self) -> Option<(i64, u8)> {
+    pub fn as_decimal8(&self) -> Option<VariantDecimal8> {
         match *self {
-            Variant::Decimal4(decimal) => Some((decimal.integer().into(), 
decimal.scale())),
-            Variant::Decimal8(decimal) => Some((decimal.integer(), 
decimal.scale())),
-            Variant::Decimal16(decimal) => {
-                if let Ok(converted_integer) = decimal.integer().try_into() {
-                    Some((converted_integer, decimal.scale()))
-                } else {
-                    None
-                }
-            }
+            Variant::Int8(i) => i64::from(i).try_into().ok(),
+            Variant::Int16(i) => i64::from(i).try_into().ok(),
+            Variant::Int32(i) => i64::from(i).try_into().ok(),
+            Variant::Int64(i) => i.try_into().ok(),
+            Variant::Decimal4(decimal4) => Some(decimal4.into()),
+            Variant::Decimal8(decimal8) => Some(decimal8),
+            Variant::Decimal16(decimal16) => decimal16.try_into().ok(),
             _ => None,
         }
     }
@@ -733,21 +735,25 @@ impl<'m, 'v> Variant<'m, 'v> {
     /// # Examples
     ///
     /// ```
-    /// use parquet_variant::{Variant, VariantDecimal16};
+    /// use parquet_variant::{Variant, VariantDecimal16, VariantDecimal4};
     ///
     /// // you can extract decimal parts from smaller or equally-sized decimal 
variants
-    /// let v1 = Variant::from(VariantDecimal16::try_new(1234_i128, 
2).unwrap());
-    /// assert_eq!(v1.as_decimal_int128(), Some((1234_i128, 2)));
+    /// let v1 = Variant::from(VariantDecimal4::try_new(1234_i32, 2).unwrap());
+    /// assert_eq!(v1.as_decimal16(), VariantDecimal16::try_new(1234_i128, 
2).ok());
     ///
     /// // but not if the variant is not a decimal
     /// let v2 = Variant::from("hello!");
-    /// assert_eq!(v2.as_decimal_int128(), None);
+    /// assert_eq!(v2.as_decimal16(), None);
     /// ```
-    pub fn as_decimal_int128(&self) -> Option<(i128, u8)> {
+    pub fn as_decimal16(&self) -> Option<VariantDecimal16> {
         match *self {
-            Variant::Decimal4(decimal) => Some((decimal.integer().into(), 
decimal.scale())),
-            Variant::Decimal8(decimal) => Some((decimal.integer().into(), 
decimal.scale())),
-            Variant::Decimal16(decimal) => Some((decimal.integer(), 
decimal.scale())),
+            Variant::Int8(i) => i128::from(i).try_into().ok(),
+            Variant::Int16(i) => i128::from(i).try_into().ok(),
+            Variant::Int32(i) => i128::from(i).try_into().ok(),
+            Variant::Int64(i) => i128::from(i).try_into().ok(),
+            Variant::Decimal4(decimal4) => Some(decimal4.into()),
+            Variant::Decimal8(decimal8) => Some(decimal8.into()),
+            Variant::Decimal16(decimal16) => Some(decimal16),
             _ => None,
         }
     }
@@ -1035,17 +1041,14 @@ mod tests {
     fn test_variant_decimal_conversion() {
         let decimal4 = VariantDecimal4::try_new(1234_i32, 2).unwrap();
         let variant = Variant::from(decimal4);
-        assert_eq!(variant.as_decimal_int32(), Some((1234_i32, 2)));
+        assert_eq!(variant.as_decimal4(), Some(decimal4));
 
         let decimal8 = VariantDecimal8::try_new(12345678901_i64, 2).unwrap();
         let variant = Variant::from(decimal8);
-        assert_eq!(variant.as_decimal_int64(), Some((12345678901_i64, 2)));
+        assert_eq!(variant.as_decimal8(), Some(decimal8));
 
         let decimal16 = 
VariantDecimal16::try_new(123456789012345678901234567890_i128, 2).unwrap();
         let variant = Variant::from(decimal16);
-        assert_eq!(
-            variant.as_decimal_int128(),
-            Some((123456789012345678901234567890_i128, 2))
-        );
+        assert_eq!(variant.as_decimal16(), Some(decimal16));
     }
 }
diff --git a/parquet-variant/src/variant/decimal.rs 
b/parquet-variant/src/variant/decimal.rs
index 852d36c520..1a897d0668 100644
--- a/parquet-variant/src/variant/decimal.rs
+++ b/parquet-variant/src/variant/decimal.rs
@@ -17,13 +17,38 @@
 use arrow_schema::ArrowError;
 use std::fmt;
 
-// Macro to format decimal values, using only integer arithmetic to avoid 
floating point precision loss
+// All decimal types use the same try_new implementation
+macro_rules! decimal_try_new {
+    ($integer:ident, $scale:ident) => {{
+        // Validate that scale doesn't exceed precision
+        if $scale > Self::MAX_PRECISION {
+            return Err(ArrowError::InvalidArgumentError(format!(
+                "Scale {} is larger than max precision {}",
+                $scale,
+                Self::MAX_PRECISION,
+            )));
+        }
+
+        // Validate that the integer value fits within the precision
+        if $integer.unsigned_abs() > Self::MAX_UNSCALED_VALUE {
+            return Err(ArrowError::InvalidArgumentError(format!(
+                "{} is wider than max precision {}",
+                $integer,
+                Self::MAX_PRECISION
+            )));
+        }
+
+        Ok(Self { $integer, $scale })
+    }};
+}
+
+// All decimal values format the same way, using integer arithmetic to avoid 
floating point precision loss
 macro_rules! format_decimal {
     ($f:expr, $integer:expr, $scale:expr, $int_type:ty) => {{
         let integer = if $scale == 0 {
             $integer
         } else {
-            let divisor = (10 as $int_type).pow($scale as u32);
+            let divisor = <$int_type>::pow(10, $scale as u32);
             let remainder = $integer % divisor;
             if remainder != 0 {
                 // Track the sign explicitly, in case the quotient is zero
@@ -61,29 +86,11 @@ pub struct VariantDecimal4 {
 }
 
 impl VariantDecimal4 {
-    const MAX_PRECISION: u32 = 9;
-    const MAX_UNSCALED_VALUE: u32 = 10_u32.pow(Self::MAX_PRECISION) - 1;
+    const MAX_PRECISION: u8 = 9;
+    const MAX_UNSCALED_VALUE: u32 = u32::pow(10, Self::MAX_PRECISION as u32) - 
1;
 
     pub fn try_new(integer: i32, scale: u8) -> Result<Self, ArrowError> {
-        // Validate that scale doesn't exceed precision
-        if scale as u32 > Self::MAX_PRECISION {
-            return Err(ArrowError::InvalidArgumentError(format!(
-                "Scale {} of a 4-byte decimal cannot exceed the max precision 
{}",
-                scale,
-                Self::MAX_PRECISION,
-            )));
-        }
-
-        // Validate that the integer value fits within the precision
-        if integer.unsigned_abs() > Self::MAX_UNSCALED_VALUE {
-            return Err(ArrowError::InvalidArgumentError(format!(
-                "{} is too large to store in a 4-byte decimal with max 
precision {}",
-                integer,
-                Self::MAX_PRECISION
-            )));
-        }
-
-        Ok(VariantDecimal4 { integer, scale })
+        decimal_try_new!(integer, scale)
     }
 
     /// Returns the underlying value of the decimal.
@@ -129,29 +136,11 @@ pub struct VariantDecimal8 {
 }
 
 impl VariantDecimal8 {
-    const MAX_PRECISION: u32 = 18;
-    const MAX_UNSCALED_VALUE: u64 = 10_u64.pow(Self::MAX_PRECISION) - 1;
+    const MAX_PRECISION: u8 = 18;
+    const MAX_UNSCALED_VALUE: u64 = u64::pow(10, Self::MAX_PRECISION as u32) - 
1;
 
     pub fn try_new(integer: i64, scale: u8) -> Result<Self, ArrowError> {
-        // Validate that scale doesn't exceed precision
-        if scale as u32 > Self::MAX_PRECISION {
-            return Err(ArrowError::InvalidArgumentError(format!(
-                "Scale {} of an 8-byte decimal cannot exceed the max precision 
{}",
-                scale,
-                Self::MAX_PRECISION,
-            )));
-        }
-
-        // Validate that the integer value fits within the precision
-        if integer.unsigned_abs() > Self::MAX_UNSCALED_VALUE {
-            return Err(ArrowError::InvalidArgumentError(format!(
-                "{} is too large to store in an 8-byte decimal with max 
precision {}",
-                integer,
-                Self::MAX_PRECISION
-            )));
-        }
-
-        Ok(VariantDecimal8 { integer, scale })
+        decimal_try_new!(integer, scale)
     }
 
     /// Returns the underlying value of the decimal.
@@ -197,29 +186,11 @@ pub struct VariantDecimal16 {
 }
 
 impl VariantDecimal16 {
-    const MAX_PRECISION: u32 = 38;
-    const MAX_UNSCALED_VALUE: u128 = 10_u128.pow(Self::MAX_PRECISION) - 1;
+    const MAX_PRECISION: u8 = 38;
+    const MAX_UNSCALED_VALUE: u128 = u128::pow(10, Self::MAX_PRECISION as u32) 
- 1;
 
     pub fn try_new(integer: i128, scale: u8) -> Result<Self, ArrowError> {
-        // Validate that scale doesn't exceed precision
-        if scale as u32 > Self::MAX_PRECISION {
-            return Err(ArrowError::InvalidArgumentError(format!(
-                "Scale {} of a 16-byte decimal cannot exceed the max precision 
{}",
-                scale,
-                Self::MAX_PRECISION,
-            )));
-        }
-
-        // Validate that the integer value fits within the precision
-        if integer.unsigned_abs() > Self::MAX_UNSCALED_VALUE {
-            return Err(ArrowError::InvalidArgumentError(format!(
-                "{} is too large to store in a 16-byte decimal with max 
precision {}",
-                integer,
-                Self::MAX_PRECISION
-            )));
-        }
-
-        Ok(VariantDecimal16 { integer, scale })
+        decimal_try_new!(integer, scale)
     }
 
     /// Returns the underlying value of the decimal.
@@ -243,6 +214,65 @@ impl fmt::Display for VariantDecimal16 {
     }
 }
 
+// Infallible conversion from a narrower decimal type to a wider one
+macro_rules! impl_from_decimal_for_decimal {
+    ($from_ty:ty, $for_ty:ty) => {
+        impl From<$from_ty> for $for_ty {
+            fn from(decimal: $from_ty) -> Self {
+                Self {
+                    integer: decimal.integer.into(),
+                    scale: decimal.scale,
+                }
+            }
+        }
+    };
+}
+
+impl_from_decimal_for_decimal!(VariantDecimal4, VariantDecimal8);
+impl_from_decimal_for_decimal!(VariantDecimal4, VariantDecimal16);
+impl_from_decimal_for_decimal!(VariantDecimal8, VariantDecimal16);
+
+// Fallible conversion from a wider decimal type to a narrower one
+macro_rules! impl_try_from_decimal_for_decimal {
+    ($from_ty:ty, $for_ty:ty) => {
+        impl TryFrom<$from_ty> for $for_ty {
+            type Error = ArrowError;
+
+            fn try_from(decimal: $from_ty) -> Result<Self, ArrowError> {
+                let Ok(integer) = decimal.integer.try_into() else {
+                    return Err(ArrowError::InvalidArgumentError(format!(
+                        "Value {} is wider than max precision {}",
+                        decimal.integer,
+                        Self::MAX_PRECISION
+                    )));
+                };
+                Self::try_new(integer, decimal.scale)
+            }
+        }
+    };
+}
+
+impl_try_from_decimal_for_decimal!(VariantDecimal8, VariantDecimal4);
+impl_try_from_decimal_for_decimal!(VariantDecimal16, VariantDecimal4);
+impl_try_from_decimal_for_decimal!(VariantDecimal16, VariantDecimal8);
+
+// Fallible conversion from a decimal's underlying integer type
+macro_rules! impl_try_from_int_for_decimal {
+    ($from_ty:ty, $for_ty:ty) => {
+        impl TryFrom<$from_ty> for $for_ty {
+            type Error = ArrowError;
+
+            fn try_from(integer: $from_ty) -> Result<Self, ArrowError> {
+                Self::try_new(integer, 0)
+            }
+        }
+    };
+}
+
+impl_try_from_int_for_decimal!(i32, VariantDecimal4);
+impl_try_from_int_for_decimal!(i64, VariantDecimal8);
+impl_try_from_int_for_decimal!(i128, VariantDecimal16);
+
 #[cfg(test)]
 mod tests {
     use super::*;
@@ -258,7 +288,7 @@ mod tests {
         assert!(decimal4_too_large
             .unwrap_err()
             .to_string()
-            .contains("too large"));
+            .contains("wider than max precision"));
 
         let decimal4_too_small = VariantDecimal4::try_new(-1_000_000_000_i32, 
2);
         assert!(
@@ -268,7 +298,7 @@ mod tests {
         assert!(decimal4_too_small
             .unwrap_err()
             .to_string()
-            .contains("too large"));
+            .contains("wider than max precision"));
 
         // Test valid edge cases for Decimal4
         let decimal4_max_valid = VariantDecimal4::try_new(999_999_999_i32, 2);
@@ -292,7 +322,7 @@ mod tests {
         assert!(decimal8_too_large
             .unwrap_err()
             .to_string()
-            .contains("too large"));
+            .contains("wider than max precision"));
 
         let decimal8_too_small = 
VariantDecimal8::try_new(-1_000_000_000_000_000_000_i64, 2);
         assert!(
@@ -302,7 +332,7 @@ mod tests {
         assert!(decimal8_too_small
             .unwrap_err()
             .to_string()
-            .contains("too large"));
+            .contains("wider than max precision"));
 
         // Test valid edge cases for Decimal8
         let decimal8_max_valid = 
VariantDecimal8::try_new(999_999_999_999_999_999_i64, 2);
@@ -327,7 +357,7 @@ mod tests {
         assert!(decimal16_too_large
             .unwrap_err()
             .to_string()
-            .contains("too large"));
+            .contains("wider than max precision"));
 
         let decimal16_too_small =
             
VariantDecimal16::try_new(-100000000000000000000000000000000000000_i128, 2);
@@ -338,7 +368,7 @@ mod tests {
         assert!(decimal16_too_small
             .unwrap_err()
             .to_string()
-            .contains("too large"));
+            .contains("wider than max precision"));
 
         // Test valid edge cases for Decimal16
         let decimal16_max_valid =
@@ -367,7 +397,7 @@ mod tests {
         assert!(decimal4_invalid_scale
             .unwrap_err()
             .to_string()
-            .contains("cannot exceed the max precision"));
+            .contains("larger than max precision"));
 
         let decimal4_invalid_scale_large = VariantDecimal4::try_new(123_i32, 
20);
         assert!(
@@ -391,7 +421,7 @@ mod tests {
         assert!(decimal8_invalid_scale
             .unwrap_err()
             .to_string()
-            .contains("cannot exceed the max precision"));
+            .contains("larger than max precision"));
 
         let decimal8_invalid_scale_large = VariantDecimal8::try_new(123_i64, 
25);
         assert!(
@@ -415,7 +445,7 @@ mod tests {
         assert!(decimal16_invalid_scale
             .unwrap_err()
             .to_string()
-            .contains("cannot exceed the max precision"));
+            .contains("larger than max precision"));
 
         let decimal16_invalid_scale_large = 
VariantDecimal16::try_new(123_i128, 50);
         assert!(

Reply via email to