This is an automated email from the ASF dual-hosted git repository.
alamb pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/arrow-rs.git
The following commit(s) were added to refs/heads/main by this push:
new 959577deab [Variant] impl [Try]From for VariantDecimalXX types (#7809)
959577deab is described below
commit 959577deabf6e27524cc7624e45e58bc2723f478
Author: Ryan Johnson <[email protected]>
AuthorDate: Tue Jul 1 04:06:01 2025 -0700
[Variant] impl [Try]From for VariantDecimalXX types (#7809)
# Which issue does this PR close?
- Part of https://github.com/apache/arrow-rs/issues/6736
# Rationale for this change
The existing `Variant::as_decimal_XX` methods were actually incorrect,
failing to validate scale when converting from a wider decimal to a
narrower one. Fix it, while also improving ergonomics of the decimal
code to reduce the chances of future issues of this type.
# What changes are included in this PR?
Add proper [Try]From for converting to decimal types from other decimals
or their underlying integer type.
Add missing conversions in the `Variant::as_int_xx` and
`Variant::as_decimal_xx` helpers.
# Are these changes tested?
TODO - need more tests for the new conversions
# Are there any user-facing changes?
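The `Variant::as_decimal_intXX` methods (`as_decimal_int32`, `as_decimal_int64`,
`as_decimal_int128`) have been renamed to `as_decimal4`, `as_decimal8` and
`as_decimal16`, and now return actual decimal types (`VariantDecimal4`,
`VariantDecimal8`, `VariantDecimal16`) instead of `(integer, scale)` tuples.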
New conversions available.
---------
Co-authored-by: Andrew Lamb <[email protected]>
---
parquet-variant/src/variant.rs | 103 ++++++++++---------
parquet-variant/src/variant/decimal.rs | 178 +++++++++++++++++++--------------
2 files changed, 157 insertions(+), 124 deletions(-)
diff --git a/parquet-variant/src/variant.rs b/parquet-variant/src/variant.rs
index 3dcb08053a..36564c2bff 100644
--- a/parquet-variant/src/variant.rs
+++ b/parquet-variant/src/variant.rs
@@ -538,6 +538,9 @@ impl<'m, 'v> Variant<'m, 'v> {
Variant::Int16(i) => i.try_into().ok(),
Variant::Int32(i) => i.try_into().ok(),
Variant::Int64(i) => i.try_into().ok(),
+ Variant::Decimal4(d) if d.scale() == 0 =>
d.integer().try_into().ok(),
+ Variant::Decimal8(d) if d.scale() == 0 =>
d.integer().try_into().ok(),
+ Variant::Decimal16(d) if d.scale() == 0 =>
d.integer().try_into().ok(),
_ => None,
}
}
@@ -570,6 +573,9 @@ impl<'m, 'v> Variant<'m, 'v> {
Variant::Int16(i) => Some(i),
Variant::Int32(i) => i.try_into().ok(),
Variant::Int64(i) => i.try_into().ok(),
+ Variant::Decimal4(d) if d.scale() == 0 =>
d.integer().try_into().ok(),
+ Variant::Decimal8(d) if d.scale() == 0 =>
d.integer().try_into().ok(),
+ Variant::Decimal16(d) if d.scale() == 0 =>
d.integer().try_into().ok(),
_ => None,
}
}
@@ -602,6 +608,9 @@ impl<'m, 'v> Variant<'m, 'v> {
Variant::Int16(i) => Some(i.into()),
Variant::Int32(i) => Some(i),
Variant::Int64(i) => i.try_into().ok(),
+ Variant::Decimal4(d) if d.scale() == 0 => Some(d.integer()),
+ Variant::Decimal8(d) if d.scale() == 0 =>
d.integer().try_into().ok(),
+ Variant::Decimal16(d) if d.scale() == 0 =>
d.integer().try_into().ok(),
_ => None,
}
}
@@ -630,6 +639,9 @@ impl<'m, 'v> Variant<'m, 'v> {
Variant::Int16(i) => Some(i.into()),
Variant::Int32(i) => Some(i.into()),
Variant::Int64(i) => Some(i),
+ Variant::Decimal4(d) if d.scale() == 0 => Some(d.integer().into()),
+ Variant::Decimal8(d) if d.scale() == 0 => Some(d.integer()),
+ Variant::Decimal16(d) if d.scale() == 0 =>
d.integer().try_into().ok(),
_ => None,
}
}
@@ -647,37 +659,29 @@ impl<'m, 'v> Variant<'m, 'v> {
///
/// // you can extract decimal parts from smaller or equally-sized decimal
variants
/// let v1 = Variant::from(VariantDecimal4::try_new(1234_i32, 2).unwrap());
- /// assert_eq!(v1.as_decimal_int32(), Some((1234_i32, 2)));
+ /// assert_eq!(v1.as_decimal4(), VariantDecimal4::try_new(1234_i32,
2).ok());
///
/// // and from larger decimal variants if they fit
/// let v2 = Variant::from(VariantDecimal8::try_new(1234_i64, 2).unwrap());
- /// assert_eq!(v2.as_decimal_int32(), Some((1234_i32, 2)));
+ /// assert_eq!(v2.as_decimal4(), VariantDecimal4::try_new(1234_i32,
2).ok());
///
/// // but not if the value would overflow i32
/// let v3 = Variant::from(VariantDecimal8::try_new(12345678901i64,
2).unwrap());
- /// assert_eq!(v3.as_decimal_int32(), None);
+ /// assert_eq!(v3.as_decimal4(), None);
///
/// // or if the variant is not a decimal
/// let v4 = Variant::from("hello!");
- /// assert_eq!(v4.as_decimal_int32(), None);
+ /// assert_eq!(v4.as_decimal4(), None);
/// ```
- pub fn as_decimal_int32(&self) -> Option<(i32, u8)> {
+ pub fn as_decimal4(&self) -> Option<VariantDecimal4> {
match *self {
- Variant::Decimal4(decimal4) => Some((decimal4.integer(),
decimal4.scale())),
- Variant::Decimal8(decimal8) => {
- if let Ok(converted_integer) = decimal8.integer().try_into() {
- Some((converted_integer, decimal8.scale()))
- } else {
- None
- }
- }
- Variant::Decimal16(decimal16) => {
- if let Ok(converted_integer) = decimal16.integer().try_into() {
- Some((converted_integer, decimal16.scale()))
- } else {
- None
- }
- }
+ Variant::Int8(i) => i32::from(i).try_into().ok(),
+ Variant::Int16(i) => i32::from(i).try_into().ok(),
+ Variant::Int32(i) => i.try_into().ok(),
+ Variant::Int64(i) => i32::try_from(i).ok()?.try_into().ok(),
+ Variant::Decimal4(decimal4) => Some(decimal4),
+ Variant::Decimal8(decimal8) => decimal8.try_into().ok(),
+ Variant::Decimal16(decimal16) => decimal16.try_into().ok(),
_ => None,
}
}
@@ -691,35 +695,33 @@ impl<'m, 'v> Variant<'m, 'v> {
/// # Examples
///
/// ```
- /// use parquet_variant::{Variant, VariantDecimal8, VariantDecimal16};
+ /// use parquet_variant::{Variant, VariantDecimal4, VariantDecimal8,
VariantDecimal16};
///
/// // you can extract decimal parts from smaller or equally-sized decimal
variants
- /// let v1 = Variant::from(VariantDecimal8::try_new(1234_i64, 2).unwrap());
- /// assert_eq!(v1.as_decimal_int64(), Some((1234_i64, 2)));
+ /// let v1 = Variant::from(VariantDecimal4::try_new(1234_i32, 2).unwrap());
+ /// assert_eq!(v1.as_decimal8(), VariantDecimal8::try_new(1234_i64,
2).ok());
///
/// // and from larger decimal variants if they fit
/// let v2 = Variant::from(VariantDecimal16::try_new(1234_i128,
2).unwrap());
- /// assert_eq!(v2.as_decimal_int64(), Some((1234_i64, 2)));
+ /// assert_eq!(v2.as_decimal8(), VariantDecimal8::try_new(1234_i64,
2).ok());
///
/// // but not if the value would overflow i64
/// let v3 = Variant::from(VariantDecimal16::try_new(2e19 as i128,
2).unwrap());
- /// assert_eq!(v3.as_decimal_int64(), None);
+ /// assert_eq!(v3.as_decimal8(), None);
///
/// // or if the variant is not a decimal
/// let v4 = Variant::from("hello!");
- /// assert_eq!(v4.as_decimal_int64(), None);
+ /// assert_eq!(v4.as_decimal8(), None);
/// ```
- pub fn as_decimal_int64(&self) -> Option<(i64, u8)> {
+ pub fn as_decimal8(&self) -> Option<VariantDecimal8> {
match *self {
- Variant::Decimal4(decimal) => Some((decimal.integer().into(),
decimal.scale())),
- Variant::Decimal8(decimal) => Some((decimal.integer(),
decimal.scale())),
- Variant::Decimal16(decimal) => {
- if let Ok(converted_integer) = decimal.integer().try_into() {
- Some((converted_integer, decimal.scale()))
- } else {
- None
- }
- }
+ Variant::Int8(i) => i64::from(i).try_into().ok(),
+ Variant::Int16(i) => i64::from(i).try_into().ok(),
+ Variant::Int32(i) => i64::from(i).try_into().ok(),
+ Variant::Int64(i) => i.try_into().ok(),
+ Variant::Decimal4(decimal4) => Some(decimal4.into()),
+ Variant::Decimal8(decimal8) => Some(decimal8),
+ Variant::Decimal16(decimal16) => decimal16.try_into().ok(),
_ => None,
}
}
@@ -733,21 +735,25 @@ impl<'m, 'v> Variant<'m, 'v> {
/// # Examples
///
/// ```
- /// use parquet_variant::{Variant, VariantDecimal16};
+ /// use parquet_variant::{Variant, VariantDecimal16, VariantDecimal4};
///
/// // you can extract decimal parts from smaller or equally-sized decimal
variants
- /// let v1 = Variant::from(VariantDecimal16::try_new(1234_i128,
2).unwrap());
- /// assert_eq!(v1.as_decimal_int128(), Some((1234_i128, 2)));
+ /// let v1 = Variant::from(VariantDecimal4::try_new(1234_i32, 2).unwrap());
+ /// assert_eq!(v1.as_decimal16(), VariantDecimal16::try_new(1234_i128,
2).ok());
///
/// // but not if the variant is not a decimal
/// let v2 = Variant::from("hello!");
- /// assert_eq!(v2.as_decimal_int128(), None);
+ /// assert_eq!(v2.as_decimal16(), None);
/// ```
- pub fn as_decimal_int128(&self) -> Option<(i128, u8)> {
+ pub fn as_decimal16(&self) -> Option<VariantDecimal16> {
match *self {
- Variant::Decimal4(decimal) => Some((decimal.integer().into(),
decimal.scale())),
- Variant::Decimal8(decimal) => Some((decimal.integer().into(),
decimal.scale())),
- Variant::Decimal16(decimal) => Some((decimal.integer(),
decimal.scale())),
+ Variant::Int8(i) => i128::from(i).try_into().ok(),
+ Variant::Int16(i) => i128::from(i).try_into().ok(),
+ Variant::Int32(i) => i128::from(i).try_into().ok(),
+ Variant::Int64(i) => i128::from(i).try_into().ok(),
+ Variant::Decimal4(decimal4) => Some(decimal4.into()),
+ Variant::Decimal8(decimal8) => Some(decimal8.into()),
+ Variant::Decimal16(decimal16) => Some(decimal16),
_ => None,
}
}
@@ -1035,17 +1041,14 @@ mod tests {
fn test_variant_decimal_conversion() {
let decimal4 = VariantDecimal4::try_new(1234_i32, 2).unwrap();
let variant = Variant::from(decimal4);
- assert_eq!(variant.as_decimal_int32(), Some((1234_i32, 2)));
+ assert_eq!(variant.as_decimal4(), Some(decimal4));
let decimal8 = VariantDecimal8::try_new(12345678901_i64, 2).unwrap();
let variant = Variant::from(decimal8);
- assert_eq!(variant.as_decimal_int64(), Some((12345678901_i64, 2)));
+ assert_eq!(variant.as_decimal8(), Some(decimal8));
let decimal16 =
VariantDecimal16::try_new(123456789012345678901234567890_i128, 2).unwrap();
let variant = Variant::from(decimal16);
- assert_eq!(
- variant.as_decimal_int128(),
- Some((123456789012345678901234567890_i128, 2))
- );
+ assert_eq!(variant.as_decimal16(), Some(decimal16));
}
}
diff --git a/parquet-variant/src/variant/decimal.rs
b/parquet-variant/src/variant/decimal.rs
index 852d36c520..1a897d0668 100644
--- a/parquet-variant/src/variant/decimal.rs
+++ b/parquet-variant/src/variant/decimal.rs
@@ -17,13 +17,38 @@
use arrow_schema::ArrowError;
use std::fmt;
-// Macro to format decimal values, using only integer arithmetic to avoid
floating point precision loss
+// All decimal types use the same try_new implementation
+macro_rules! decimal_try_new {
+ ($integer:ident, $scale:ident) => {{
+ // Validate that scale doesn't exceed precision
+ if $scale > Self::MAX_PRECISION {
+ return Err(ArrowError::InvalidArgumentError(format!(
+ "Scale {} is larger than max precision {}",
+ $scale,
+ Self::MAX_PRECISION,
+ )));
+ }
+
+ // Validate that the integer value fits within the precision
+ if $integer.unsigned_abs() > Self::MAX_UNSCALED_VALUE {
+ return Err(ArrowError::InvalidArgumentError(format!(
+ "{} is wider than max precision {}",
+ $integer,
+ Self::MAX_PRECISION
+ )));
+ }
+
+ Ok(Self { $integer, $scale })
+ }};
+}
+
+// All decimal values format the same way, using integer arithmetic to avoid
floating point precision loss
macro_rules! format_decimal {
($f:expr, $integer:expr, $scale:expr, $int_type:ty) => {{
let integer = if $scale == 0 {
$integer
} else {
- let divisor = (10 as $int_type).pow($scale as u32);
+ let divisor = <$int_type>::pow(10, $scale as u32);
let remainder = $integer % divisor;
if remainder != 0 {
// Track the sign explicitly, in case the quotient is zero
@@ -61,29 +86,11 @@ pub struct VariantDecimal4 {
}
impl VariantDecimal4 {
- const MAX_PRECISION: u32 = 9;
- const MAX_UNSCALED_VALUE: u32 = 10_u32.pow(Self::MAX_PRECISION) - 1;
+ const MAX_PRECISION: u8 = 9;
+ const MAX_UNSCALED_VALUE: u32 = u32::pow(10, Self::MAX_PRECISION as u32) -
1;
pub fn try_new(integer: i32, scale: u8) -> Result<Self, ArrowError> {
- // Validate that scale doesn't exceed precision
- if scale as u32 > Self::MAX_PRECISION {
- return Err(ArrowError::InvalidArgumentError(format!(
- "Scale {} of a 4-byte decimal cannot exceed the max precision
{}",
- scale,
- Self::MAX_PRECISION,
- )));
- }
-
- // Validate that the integer value fits within the precision
- if integer.unsigned_abs() > Self::MAX_UNSCALED_VALUE {
- return Err(ArrowError::InvalidArgumentError(format!(
- "{} is too large to store in a 4-byte decimal with max
precision {}",
- integer,
- Self::MAX_PRECISION
- )));
- }
-
- Ok(VariantDecimal4 { integer, scale })
+ decimal_try_new!(integer, scale)
}
/// Returns the underlying value of the decimal.
@@ -129,29 +136,11 @@ pub struct VariantDecimal8 {
}
impl VariantDecimal8 {
- const MAX_PRECISION: u32 = 18;
- const MAX_UNSCALED_VALUE: u64 = 10_u64.pow(Self::MAX_PRECISION) - 1;
+ const MAX_PRECISION: u8 = 18;
+ const MAX_UNSCALED_VALUE: u64 = u64::pow(10, Self::MAX_PRECISION as u32) -
1;
pub fn try_new(integer: i64, scale: u8) -> Result<Self, ArrowError> {
- // Validate that scale doesn't exceed precision
- if scale as u32 > Self::MAX_PRECISION {
- return Err(ArrowError::InvalidArgumentError(format!(
- "Scale {} of an 8-byte decimal cannot exceed the max precision
{}",
- scale,
- Self::MAX_PRECISION,
- )));
- }
-
- // Validate that the integer value fits within the precision
- if integer.unsigned_abs() > Self::MAX_UNSCALED_VALUE {
- return Err(ArrowError::InvalidArgumentError(format!(
- "{} is too large to store in an 8-byte decimal with max
precision {}",
- integer,
- Self::MAX_PRECISION
- )));
- }
-
- Ok(VariantDecimal8 { integer, scale })
+ decimal_try_new!(integer, scale)
}
/// Returns the underlying value of the decimal.
@@ -197,29 +186,11 @@ pub struct VariantDecimal16 {
}
impl VariantDecimal16 {
- const MAX_PRECISION: u32 = 38;
- const MAX_UNSCALED_VALUE: u128 = 10_u128.pow(Self::MAX_PRECISION) - 1;
+ const MAX_PRECISION: u8 = 38;
+ const MAX_UNSCALED_VALUE: u128 = u128::pow(10, Self::MAX_PRECISION as u32)
- 1;
pub fn try_new(integer: i128, scale: u8) -> Result<Self, ArrowError> {
- // Validate that scale doesn't exceed precision
- if scale as u32 > Self::MAX_PRECISION {
- return Err(ArrowError::InvalidArgumentError(format!(
- "Scale {} of a 16-byte decimal cannot exceed the max precision
{}",
- scale,
- Self::MAX_PRECISION,
- )));
- }
-
- // Validate that the integer value fits within the precision
- if integer.unsigned_abs() > Self::MAX_UNSCALED_VALUE {
- return Err(ArrowError::InvalidArgumentError(format!(
- "{} is too large to store in a 16-byte decimal with max
precision {}",
- integer,
- Self::MAX_PRECISION
- )));
- }
-
- Ok(VariantDecimal16 { integer, scale })
+ decimal_try_new!(integer, scale)
}
/// Returns the underlying value of the decimal.
@@ -243,6 +214,65 @@ impl fmt::Display for VariantDecimal16 {
}
}
+// Infallible conversion from a narrower decimal type to a wider one
+macro_rules! impl_from_decimal_for_decimal {
+ ($from_ty:ty, $for_ty:ty) => {
+ impl From<$from_ty> for $for_ty {
+ fn from(decimal: $from_ty) -> Self {
+ Self {
+ integer: decimal.integer.into(),
+ scale: decimal.scale,
+ }
+ }
+ }
+ };
+}
+
+impl_from_decimal_for_decimal!(VariantDecimal4, VariantDecimal8);
+impl_from_decimal_for_decimal!(VariantDecimal4, VariantDecimal16);
+impl_from_decimal_for_decimal!(VariantDecimal8, VariantDecimal16);
+
+// Fallible conversion from a wider decimal type to a narrower one
+macro_rules! impl_try_from_decimal_for_decimal {
+ ($from_ty:ty, $for_ty:ty) => {
+ impl TryFrom<$from_ty> for $for_ty {
+ type Error = ArrowError;
+
+ fn try_from(decimal: $from_ty) -> Result<Self, ArrowError> {
+ let Ok(integer) = decimal.integer.try_into() else {
+ return Err(ArrowError::InvalidArgumentError(format!(
+ "Value {} is wider than max precision {}",
+ decimal.integer,
+ Self::MAX_PRECISION
+ )));
+ };
+ Self::try_new(integer, decimal.scale)
+ }
+ }
+ };
+}
+
+impl_try_from_decimal_for_decimal!(VariantDecimal8, VariantDecimal4);
+impl_try_from_decimal_for_decimal!(VariantDecimal16, VariantDecimal4);
+impl_try_from_decimal_for_decimal!(VariantDecimal16, VariantDecimal8);
+
+// Fallible conversion from a decimal's underlying integer type
+macro_rules! impl_try_from_int_for_decimal {
+ ($from_ty:ty, $for_ty:ty) => {
+ impl TryFrom<$from_ty> for $for_ty {
+ type Error = ArrowError;
+
+ fn try_from(integer: $from_ty) -> Result<Self, ArrowError> {
+ Self::try_new(integer, 0)
+ }
+ }
+ };
+}
+
+impl_try_from_int_for_decimal!(i32, VariantDecimal4);
+impl_try_from_int_for_decimal!(i64, VariantDecimal8);
+impl_try_from_int_for_decimal!(i128, VariantDecimal16);
+
#[cfg(test)]
mod tests {
use super::*;
@@ -258,7 +288,7 @@ mod tests {
assert!(decimal4_too_large
.unwrap_err()
.to_string()
- .contains("too large"));
+ .contains("wider than max precision"));
let decimal4_too_small = VariantDecimal4::try_new(-1_000_000_000_i32,
2);
assert!(
@@ -268,7 +298,7 @@ mod tests {
assert!(decimal4_too_small
.unwrap_err()
.to_string()
- .contains("too large"));
+ .contains("wider than max precision"));
// Test valid edge cases for Decimal4
let decimal4_max_valid = VariantDecimal4::try_new(999_999_999_i32, 2);
@@ -292,7 +322,7 @@ mod tests {
assert!(decimal8_too_large
.unwrap_err()
.to_string()
- .contains("too large"));
+ .contains("wider than max precision"));
let decimal8_too_small =
VariantDecimal8::try_new(-1_000_000_000_000_000_000_i64, 2);
assert!(
@@ -302,7 +332,7 @@ mod tests {
assert!(decimal8_too_small
.unwrap_err()
.to_string()
- .contains("too large"));
+ .contains("wider than max precision"));
// Test valid edge cases for Decimal8
let decimal8_max_valid =
VariantDecimal8::try_new(999_999_999_999_999_999_i64, 2);
@@ -327,7 +357,7 @@ mod tests {
assert!(decimal16_too_large
.unwrap_err()
.to_string()
- .contains("too large"));
+ .contains("wider than max precision"));
let decimal16_too_small =
VariantDecimal16::try_new(-100000000000000000000000000000000000000_i128, 2);
@@ -338,7 +368,7 @@ mod tests {
assert!(decimal16_too_small
.unwrap_err()
.to_string()
- .contains("too large"));
+ .contains("wider than max precision"));
// Test valid edge cases for Decimal16
let decimal16_max_valid =
@@ -367,7 +397,7 @@ mod tests {
assert!(decimal4_invalid_scale
.unwrap_err()
.to_string()
- .contains("cannot exceed the max precision"));
+ .contains("larger than max precision"));
let decimal4_invalid_scale_large = VariantDecimal4::try_new(123_i32,
20);
assert!(
@@ -391,7 +421,7 @@ mod tests {
assert!(decimal8_invalid_scale
.unwrap_err()
.to_string()
- .contains("cannot exceed the max precision"));
+ .contains("larger than max precision"));
let decimal8_invalid_scale_large = VariantDecimal8::try_new(123_i64,
25);
assert!(
@@ -415,7 +445,7 @@ mod tests {
assert!(decimal16_invalid_scale
.unwrap_err()
.to_string()
- .contains("cannot exceed the max precision"));
+ .contains("larger than max precision"));
let decimal16_invalid_scale_large =
VariantDecimal16::try_new(123_i128, 50);
assert!(