This is an automated email from the ASF dual-hosted git repository.

alamb pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/arrow-rs.git


The following commit(s) were added to refs/heads/main by this push:
     new b8fdd90728 [Variant] Define and use VariantDecimalType trait (#8562)
b8fdd90728 is described below

commit b8fdd907283ab0a3e4ef1b4647df2a49dcbe647f
Author: Ryan Johnson <[email protected]>
AuthorDate: Tue Oct 14 10:10:17 2025 -0600

    [Variant] Define and use VariantDecimalType trait (#8562)
    
    # Which issue does this PR close?
    
    We generally require a GitHub issue to be filed for all bug fixes and
    enhancements and this helps us generate change logs for our releases.
    You can link an issue to this PR using the GitHub syntax.
    
    - Closes #NNN.
    
    # Rationale for this change
    
    `VariantDecimalXX` structs are structurally near-identical but lack any
    trait that can expose that regularity.
    
    # What changes are included in this PR?
    
    Define and use a new `VariantDecimalType` trait that exposes common
    functionality of all three variant decimal types.
    
    # Are these changes tested?
    
    Yes, existing unit tests cover the changes.
    
    # Are there any user-facing changes?
    
    Yes: a new public trait, `VariantDecimalType`, is exported from `parquet_variant`.
---
 parquet-variant-compute/src/arrow_to_variant.rs | 120 ++++------
 parquet-variant-compute/src/type_conversion.rs  |  19 --
 parquet-variant-compute/src/unshred_variant.rs  | 150 ++++--------
 parquet-variant-compute/src/variant_array.rs    |  32 +--
 parquet-variant/src/variant.rs                  |   2 +-
 parquet-variant/src/variant/decimal.rs          | 302 +++++++++++++++---------
 6 files changed, 293 insertions(+), 332 deletions(-)

diff --git a/parquet-variant-compute/src/arrow_to_variant.rs 
b/parquet-variant-compute/src/arrow_to_variant.rs
index fe0c521090..5e01aba3c1 100644
--- a/parquet-variant-compute/src/arrow_to_variant.rs
+++ b/parquet-variant-compute/src/arrow_to_variant.rs
@@ -15,25 +15,22 @@
 // specific language governing permissions and limitations
 // under the License.
 
-use crate::type_conversion::{CastOptions, decimal_to_variant_decimal};
+use crate::type_conversion::CastOptions;
 use arrow::array::{
     Array, AsArray, FixedSizeListArray, GenericBinaryArray, GenericListArray, 
GenericListViewArray,
     GenericStringArray, OffsetSizeTrait, PrimitiveArray,
 };
 use arrow::compute::kernels::cast;
 use arrow::datatypes::{
-    ArrowNativeType, ArrowPrimitiveType, ArrowTemporalType, 
ArrowTimestampType, Date32Type,
-    Date64Type, Float16Type, Float32Type, Float64Type, Int8Type, Int16Type, 
Int32Type, Int64Type,
-    RunEndIndexType, Time32MillisecondType, Time32SecondType, 
Time64MicrosecondType,
-    Time64NanosecondType, TimestampMicrosecondType, TimestampMillisecondType,
-    TimestampNanosecondType, TimestampSecondType, UInt8Type, UInt16Type, 
UInt32Type, UInt64Type,
+    self as datatypes, ArrowNativeType, ArrowPrimitiveType, ArrowTemporalType, 
ArrowTimestampType,
+    DecimalType, RunEndIndexType,
 };
 use arrow::temporal_conversions::{as_date, as_datetime, as_time};
 use arrow_schema::{ArrowError, DataType, TimeUnit};
 use chrono::{DateTime, TimeZone, Utc};
 use parquet_variant::{
     ObjectFieldBuilder, Variant, VariantBuilderExt, VariantDecimal4, 
VariantDecimal8,
-    VariantDecimal16,
+    VariantDecimal16, VariantDecimalType,
 };
 use std::collections::HashMap;
 use std::ops::Range;
@@ -46,31 +43,31 @@ use std::ops::Range;
 pub(crate) enum ArrowToVariantRowBuilder<'a> {
     Null(NullArrowToVariantBuilder),
     Boolean(BooleanArrowToVariantBuilder<'a>),
-    PrimitiveInt8(PrimitiveArrowToVariantBuilder<'a, Int8Type>),
-    PrimitiveInt16(PrimitiveArrowToVariantBuilder<'a, Int16Type>),
-    PrimitiveInt32(PrimitiveArrowToVariantBuilder<'a, Int32Type>),
-    PrimitiveInt64(PrimitiveArrowToVariantBuilder<'a, Int64Type>),
-    PrimitiveUInt8(PrimitiveArrowToVariantBuilder<'a, UInt8Type>),
-    PrimitiveUInt16(PrimitiveArrowToVariantBuilder<'a, UInt16Type>),
-    PrimitiveUInt32(PrimitiveArrowToVariantBuilder<'a, UInt32Type>),
-    PrimitiveUInt64(PrimitiveArrowToVariantBuilder<'a, UInt64Type>),
-    PrimitiveFloat16(PrimitiveArrowToVariantBuilder<'a, Float16Type>),
-    PrimitiveFloat32(PrimitiveArrowToVariantBuilder<'a, Float32Type>),
-    PrimitiveFloat64(PrimitiveArrowToVariantBuilder<'a, Float64Type>),
-    Decimal32(Decimal32ArrowToVariantBuilder<'a>),
-    Decimal64(Decimal64ArrowToVariantBuilder<'a>),
-    Decimal128(Decimal128ArrowToVariantBuilder<'a>),
+    PrimitiveInt8(PrimitiveArrowToVariantBuilder<'a, datatypes::Int8Type>),
+    PrimitiveInt16(PrimitiveArrowToVariantBuilder<'a, datatypes::Int16Type>),
+    PrimitiveInt32(PrimitiveArrowToVariantBuilder<'a, datatypes::Int32Type>),
+    PrimitiveInt64(PrimitiveArrowToVariantBuilder<'a, datatypes::Int64Type>),
+    PrimitiveUInt8(PrimitiveArrowToVariantBuilder<'a, datatypes::UInt8Type>),
+    PrimitiveUInt16(PrimitiveArrowToVariantBuilder<'a, datatypes::UInt16Type>),
+    PrimitiveUInt32(PrimitiveArrowToVariantBuilder<'a, datatypes::UInt32Type>),
+    PrimitiveUInt64(PrimitiveArrowToVariantBuilder<'a, datatypes::UInt64Type>),
+    PrimitiveFloat16(PrimitiveArrowToVariantBuilder<'a, 
datatypes::Float16Type>),
+    PrimitiveFloat32(PrimitiveArrowToVariantBuilder<'a, 
datatypes::Float32Type>),
+    PrimitiveFloat64(PrimitiveArrowToVariantBuilder<'a, 
datatypes::Float64Type>),
+    Decimal32(DecimalArrowToVariantBuilder<'a, datatypes::Decimal32Type, 
VariantDecimal4>),
+    Decimal64(DecimalArrowToVariantBuilder<'a, datatypes::Decimal64Type, 
VariantDecimal8>),
+    Decimal128(DecimalArrowToVariantBuilder<'a, datatypes::Decimal128Type, 
VariantDecimal16>),
     Decimal256(Decimal256ArrowToVariantBuilder<'a>),
-    TimestampSecond(TimestampArrowToVariantBuilder<'a, TimestampSecondType>),
-    TimestampMillisecond(TimestampArrowToVariantBuilder<'a, 
TimestampMillisecondType>),
-    TimestampMicrosecond(TimestampArrowToVariantBuilder<'a, 
TimestampMicrosecondType>),
-    TimestampNanosecond(TimestampArrowToVariantBuilder<'a, 
TimestampNanosecondType>),
-    Date32(DateArrowToVariantBuilder<'a, Date32Type>),
-    Date64(DateArrowToVariantBuilder<'a, Date64Type>),
-    Time32Second(TimeArrowToVariantBuilder<'a, Time32SecondType>),
-    Time32Millisecond(TimeArrowToVariantBuilder<'a, Time32MillisecondType>),
-    Time64Microsecond(TimeArrowToVariantBuilder<'a, Time64MicrosecondType>),
-    Time64Nanosecond(TimeArrowToVariantBuilder<'a, Time64NanosecondType>),
+    TimestampSecond(TimestampArrowToVariantBuilder<'a, 
datatypes::TimestampSecondType>),
+    TimestampMillisecond(TimestampArrowToVariantBuilder<'a, 
datatypes::TimestampMillisecondType>),
+    TimestampMicrosecond(TimestampArrowToVariantBuilder<'a, 
datatypes::TimestampMicrosecondType>),
+    TimestampNanosecond(TimestampArrowToVariantBuilder<'a, 
datatypes::TimestampNanosecondType>),
+    Date32(DateArrowToVariantBuilder<'a, datatypes::Date32Type>),
+    Date64(DateArrowToVariantBuilder<'a, datatypes::Date64Type>),
+    Time32Second(TimeArrowToVariantBuilder<'a, datatypes::Time32SecondType>),
+    Time32Millisecond(TimeArrowToVariantBuilder<'a, 
datatypes::Time32MillisecondType>),
+    Time64Microsecond(TimeArrowToVariantBuilder<'a, 
datatypes::Time64MicrosecondType>),
+    Time64Nanosecond(TimeArrowToVariantBuilder<'a, 
datatypes::Time64NanosecondType>),
     Binary(BinaryArrowToVariantBuilder<'a, i32>),
     LargeBinary(BinaryArrowToVariantBuilder<'a, i64>),
     BinaryView(BinaryViewArrowToVariantBuilder<'a>),
@@ -87,9 +84,9 @@ pub(crate) enum ArrowToVariantRowBuilder<'a> {
     Map(MapArrowToVariantBuilder<'a>),
     Union(UnionArrowToVariantBuilder<'a>),
     Dictionary(DictionaryArrowToVariantBuilder<'a>),
-    RunEndEncodedInt16(RunEndEncodedArrowToVariantBuilder<'a, Int16Type>),
-    RunEndEncodedInt32(RunEndEncodedArrowToVariantBuilder<'a, Int32Type>),
-    RunEndEncodedInt64(RunEndEncodedArrowToVariantBuilder<'a, Int64Type>),
+    RunEndEncodedInt16(RunEndEncodedArrowToVariantBuilder<'a, 
datatypes::Int16Type>),
+    RunEndEncodedInt32(RunEndEncodedArrowToVariantBuilder<'a, 
datatypes::Int32Type>),
+    RunEndEncodedInt64(RunEndEncodedArrowToVariantBuilder<'a, 
datatypes::Int64Type>),
 }
 
 impl<'a> ArrowToVariantRowBuilder<'a> {
@@ -174,13 +171,13 @@ pub(crate) fn make_arrow_to_variant_row_builder<'a>(
             DataType::Float32 => 
PrimitiveFloat32(PrimitiveArrowToVariantBuilder::new(array)),
             DataType::Float64 => 
PrimitiveFloat64(PrimitiveArrowToVariantBuilder::new(array)),
             DataType::Decimal32(_, scale) => {
-                Decimal32(Decimal32ArrowToVariantBuilder::new(array, options, 
*scale))
+                Decimal32(DecimalArrowToVariantBuilder::new(array, options, 
*scale))
             }
             DataType::Decimal64(_, scale) => {
-                Decimal64(Decimal64ArrowToVariantBuilder::new(array, options, 
*scale))
+                Decimal64(DecimalArrowToVariantBuilder::new(array, options, 
*scale))
             }
             DataType::Decimal128(_, scale) => {
-                Decimal128(Decimal128ArrowToVariantBuilder::new(array, 
options, *scale))
+                Decimal128(DecimalArrowToVariantBuilder::new(array, options, 
*scale))
             }
             DataType::Decimal256(_, scale) => {
                 Decimal256(Decimal256ArrowToVariantBuilder::new(array, 
options, *scale))
@@ -320,26 +317,28 @@ pub(crate) fn make_arrow_to_variant_row_builder<'a>(
 // worth the trouble, tho, because it makes for some pretty bulky and unwieldy 
macro expansions.
 macro_rules! define_row_builder {
     (
-        struct $name:ident<$lifetime:lifetime $(, $generic:ident: $bound:path 
)?>
+        struct $name:ident<$lifetime:lifetime $(, $generic:ident $( : 
$bound:path )? )*>
         $( where $where_path:path: $where_bound:path $(,)? )?
-        $({ $($field:ident: $field_type:ty),+ $(,)? })?,
+        $({ $( $field:ident: $field_type:ty ),+ $(,)? })?,
         |$array_param:ident| -> $array_type:ty { $init_expr:expr }
-        $(, |$value:ident| $(-> Option<$option_ty:ty>)? $value_transform:expr)?
+        $(, |$value:ident| $(-> Option<$option_ty:ty>)? $value_transform:expr 
)?
     ) => {
-        pub(crate) struct $name<$lifetime $(, $generic: $bound )?>
+        pub(crate) struct $name<$lifetime $(, $generic: $( $bound )? )*>
         $( where $where_path: $where_bound )?
         {
             array: &$lifetime $array_type,
             $( $( $field: $field_type, )+ )?
+            _phantom: std::marker::PhantomData<($( $generic, )*)>, // capture 
all type params
         }
 
-        impl<$lifetime $(, $generic: $bound+ )?> $name<$lifetime $(, 
$generic)?>
+        impl<$lifetime $(, $generic: $( $bound )? )*> $name<$lifetime $(, 
$generic)*>
         $( where $where_path: $where_bound )?
         {
-            pub(crate) fn new($array_param: &$lifetime dyn Array $(, $( 
$field: $field_type ),+ )?) -> Self {
+            pub(crate) fn new($array_param: &$lifetime dyn Array $( $(, 
$field: $field_type )+ )?) -> Self {
                 Self {
                     array: $init_expr,
                     $( $( $field, )+ )?
+                    _phantom: std::marker::PhantomData,
                 }
             }
 
@@ -401,32 +400,18 @@ define_row_builder!(
 );
 
 define_row_builder!(
-    struct Decimal32ArrowToVariantBuilder<'a> {
-        options: &'a CastOptions,
-        scale: i8,
-    },
-    |array| -> arrow::array::Decimal32Array { array.as_primitive() },
-    |value| -> Option<_> { decimal_to_variant_decimal!(value, scale, i32, 
VariantDecimal4) }
-);
-
-define_row_builder!(
-    struct Decimal64ArrowToVariantBuilder<'a> {
-        options: &'a CastOptions,
-        scale: i8,
-    },
-    |array| -> arrow::array::Decimal64Array { array.as_primitive() },
-    |value| -> Option<_> { decimal_to_variant_decimal!(value, scale, i64, 
VariantDecimal8) }
-);
-
-define_row_builder!(
-    struct Decimal128ArrowToVariantBuilder<'a> {
+    struct DecimalArrowToVariantBuilder<'a, A: DecimalType, V>
+    where
+        V: VariantDecimalType<Native = A::Native>,
+    {
         options: &'a CastOptions,
         scale: i8,
     },
-    |array| -> arrow::array::Decimal128Array { array.as_primitive() },
-    |value| -> Option<_> { decimal_to_variant_decimal!(value, scale, i128, 
VariantDecimal16) }
+    |array| -> PrimitiveArray<A> { array.as_primitive() },
+    |value| -> Option<_> { V::try_new_with_signed_scale(value, *scale).ok() }
 );
 
+// Decimal256 needs a two-stage conversion via i128
 define_row_builder!(
     struct Decimal256ArrowToVariantBuilder<'a> {
         options: &'a CastOptions,
@@ -434,10 +419,8 @@ define_row_builder!(
     },
     |array| -> arrow::array::Decimal256Array { array.as_primitive() },
     |value| -> Option<_> {
-        // Decimal256 needs special handling - convert to i128 if possible
-        value.to_i128().and_then(|i128_val| {
-            decimal_to_variant_decimal!(i128_val, scale, i128, 
VariantDecimal16)
-        })
+        let value = value.to_i128();
+        value.and_then(|v| VariantDecimal16::try_new_with_signed_scale(v, 
*scale).ok())
     }
 );
 
@@ -911,6 +894,7 @@ mod tests {
     use super::*;
     use crate::{VariantArray, VariantArrayBuilder};
     use arrow::array::{ArrayRef, BooleanArray, Int32Array, StringArray};
+    use arrow::datatypes::Int32Type;
     use std::sync::Arc;
 
     /// Builds a VariantArray from an Arrow array using the row builder.
diff --git a/parquet-variant-compute/src/type_conversion.rs 
b/parquet-variant-compute/src/type_conversion.rs
index 5afebb1bfa..7851ccc735 100644
--- a/parquet-variant-compute/src/type_conversion.rs
+++ b/parquet-variant-compute/src/type_conversion.rs
@@ -150,22 +150,3 @@ macro_rules! primitive_conversion_single_value {
     }};
 }
 pub(crate) use primitive_conversion_single_value;
-
-/// Convert a decimal value to a `VariantDecimal`
-macro_rules! decimal_to_variant_decimal {
-    ($v:ident, $scale:expr, $value_type:ty, $variant_type:ty) => {{
-        let (v, scale) = if *$scale < 0 {
-            // For negative scale, we need to multiply the value by 10^|scale|
-            // For example: 123 with scale -2 becomes 12300 with scale 0
-            let multiplier = <$value_type>::pow(10, (-*$scale) as u32);
-            (<$value_type>::checked_mul($v, multiplier), 0u8)
-        } else {
-            (Some($v), *$scale as u8)
-        };
-
-        // Return an Option to allow callers to decide whether to error 
(strict)
-        // or append null (non-strict) on conversion failure
-        v.and_then(|v| <$variant_type>::try_new(v, scale).ok())
-    }};
-}
-pub(crate) use decimal_to_variant_decimal;
diff --git a/parquet-variant-compute/src/unshred_variant.rs 
b/parquet-variant-compute/src/unshred_variant.rs
index 64eaa46ed0..c20bb69790 100644
--- a/parquet-variant-compute/src/unshred_variant.rs
+++ b/parquet-variant-compute/src/unshred_variant.rs
@@ -35,8 +35,9 @@ use chrono::{DateTime, Utc};
 use indexmap::IndexMap;
 use parquet_variant::{
     ObjectFieldBuilder, Variant, VariantBuilderExt, VariantDecimal4, 
VariantDecimal8,
-    VariantDecimal16, VariantMetadata,
+    VariantDecimal16, VariantDecimalType, VariantMetadata,
 };
+use std::marker::PhantomData;
 use uuid::Uuid;
 
 /// Removes all (nested) typed_value columns from a VariantArray by converting 
them back to binary
@@ -95,9 +96,9 @@ enum UnshredVariantRowBuilder<'a> {
     PrimitiveInt64(UnshredPrimitiveRowBuilder<'a, PrimitiveArray<Int64Type>>),
     PrimitiveFloat32(UnshredPrimitiveRowBuilder<'a, 
PrimitiveArray<Float32Type>>),
     PrimitiveFloat64(UnshredPrimitiveRowBuilder<'a, 
PrimitiveArray<Float64Type>>),
-    Decimal32(DecimalUnshredRowBuilder<'a, Decimal32Spec>),
-    Decimal64(DecimalUnshredRowBuilder<'a, Decimal64Spec>),
-    Decimal128(DecimalUnshredRowBuilder<'a, Decimal128Spec>),
+    Decimal32(DecimalUnshredRowBuilder<'a, Decimal32Type, VariantDecimal4>),
+    Decimal64(DecimalUnshredRowBuilder<'a, Decimal64Type, VariantDecimal8>),
+    Decimal128(DecimalUnshredRowBuilder<'a, Decimal128Type, VariantDecimal16>),
     PrimitiveDate32(UnshredPrimitiveRowBuilder<'a, 
PrimitiveArray<Date32Type>>),
     PrimitiveTime64(UnshredPrimitiveRowBuilder<'a, 
PrimitiveArray<Time64MicrosecondType>>),
     TimestampMicrosecond(TimestampUnshredRowBuilder<'a, 
TimestampMicrosecondType>),
@@ -185,25 +186,23 @@ impl<'a> UnshredVariantRowBuilder<'a> {
             DataType::Int64 => primitive_builder!(PrimitiveInt64, 
as_primitive),
             DataType::Float32 => primitive_builder!(PrimitiveFloat32, 
as_primitive),
             DataType::Float64 => primitive_builder!(PrimitiveFloat64, 
as_primitive),
-            DataType::Decimal32(_, scale) => 
Self::Decimal32(DecimalUnshredRowBuilder::new(
-                value,
-                typed_value.as_primitive(),
-                *scale,
-            )),
-            DataType::Decimal64(_, scale) => 
Self::Decimal64(DecimalUnshredRowBuilder::new(
-                value,
-                typed_value.as_primitive(),
-                *scale,
-            )),
-            DataType::Decimal128(_, scale) => 
Self::Decimal128(DecimalUnshredRowBuilder::new(
-                value,
-                typed_value.as_primitive(),
-                *scale,
-            )),
-            DataType::Decimal256(_, _) => {
-                return Err(ArrowError::InvalidArgumentError(
-                    "Decimal256 is not a valid variant shredding 
type".to_string(),
-                ));
+            DataType::Decimal32(p, s) if 
VariantDecimal4::is_valid_precision_and_scale(p, s) => {
+                Self::Decimal32(DecimalUnshredRowBuilder::new(value, 
typed_value, *s as _))
+            }
+            DataType::Decimal64(p, s) if 
VariantDecimal8::is_valid_precision_and_scale(p, s) => {
+                Self::Decimal64(DecimalUnshredRowBuilder::new(value, 
typed_value, *s as _))
+            }
+            DataType::Decimal128(p, s) if 
VariantDecimal16::is_valid_precision_and_scale(p, s) => {
+                Self::Decimal128(DecimalUnshredRowBuilder::new(value, 
typed_value, *s as _))
+            }
+            DataType::Decimal32(_, _)
+            | DataType::Decimal64(_, _)
+            | DataType::Decimal128(_, _)
+            | DataType::Decimal256(_, _) => {
+                return Err(ArrowError::InvalidArgumentError(format!(
+                    "{} is not a valid variant shredding type",
+                    typed_value.data_type()
+                )));
             }
             DataType::Date32 => primitive_builder!(PrimitiveDate32, 
as_primitive),
             DataType::Time64(TimeUnit::Microsecond) => {
@@ -214,20 +213,12 @@ impl<'a> UnshredVariantRowBuilder<'a> {
                     "Time64({time_unit}) is not a valid variant shredding 
type",
                 )));
             }
-            DataType::Timestamp(TimeUnit::Microsecond, timezone) => {
-                Self::TimestampMicrosecond(TimestampUnshredRowBuilder::new(
-                    value,
-                    typed_value.as_primitive(),
-                    timezone.is_some(),
-                ))
-            }
-            DataType::Timestamp(TimeUnit::Nanosecond, timezone) => {
-                Self::TimestampNanosecond(TimestampUnshredRowBuilder::new(
-                    value,
-                    typed_value.as_primitive(),
-                    timezone.is_some(),
-                ))
-            }
+            DataType::Timestamp(TimeUnit::Microsecond, timezone) => 
Self::TimestampMicrosecond(
+                TimestampUnshredRowBuilder::new(value, typed_value, 
timezone.is_some()),
+            ),
+            DataType::Timestamp(TimeUnit::Nanosecond, timezone) => 
Self::TimestampNanosecond(
+                TimestampUnshredRowBuilder::new(value, typed_value, 
timezone.is_some()),
+            ),
             DataType::Timestamp(time_unit, _) => {
                 return Err(ArrowError::InvalidArgumentError(format!(
                     "Timestamp({time_unit}) is not a valid variant shredding 
type",
@@ -474,12 +465,12 @@ struct TimestampUnshredRowBuilder<'a, T: TimestampType> {
 impl<'a, T: TimestampType> TimestampUnshredRowBuilder<'a, T> {
     fn new(
         value: Option<&'a BinaryViewArray>,
-        typed_value: &'a PrimitiveArray<T>,
+        typed_value: &'a dyn Array,
         has_timezone: bool,
     ) -> Self {
         Self {
             value,
-            typed_value,
+            typed_value: typed_value.as_primitive(),
             has_timezone,
         }
     }
@@ -504,78 +495,27 @@ impl<'a, T: TimestampType> TimestampUnshredRowBuilder<'a, 
T> {
     }
 }
 
-/// Trait to unify decimal unshredding across Decimal32/64/128 types
-trait DecimalSpec {
-    type Arrow: ArrowPrimitiveType + DecimalType;
-
-    fn into_variant(
-        raw: <Self::Arrow as ArrowPrimitiveType>::Native,
-        scale: i8,
-    ) -> Result<Variant<'static, 'static>>;
-}
-
-/// Spec for Decimal32 -> VariantDecimal4
-struct Decimal32Spec;
-
-impl DecimalSpec for Decimal32Spec {
-    type Arrow = Decimal32Type;
-
-    fn into_variant(raw: i32, scale: i8) -> Result<Variant<'static, 'static>> {
-        let scale =
-            u8::try_from(scale).map_err(|e| 
ArrowError::InvalidArgumentError(e.to_string()))?;
-        let value = VariantDecimal4::try_new(raw, scale)
-            .map_err(|e| ArrowError::InvalidArgumentError(e.to_string()))?;
-        Ok(value.into())
-    }
-}
-
-/// Spec for Decimal64 -> VariantDecimal8
-struct Decimal64Spec;
-
-impl DecimalSpec for Decimal64Spec {
-    type Arrow = Decimal64Type;
-
-    fn into_variant(raw: i64, scale: i8) -> Result<Variant<'static, 'static>> {
-        let scale =
-            u8::try_from(scale).map_err(|e| 
ArrowError::InvalidArgumentError(e.to_string()))?;
-        let value = VariantDecimal8::try_new(raw, scale)
-            .map_err(|e| ArrowError::InvalidArgumentError(e.to_string()))?;
-        Ok(value.into())
-    }
-}
-
-/// Spec for Decimal128 -> VariantDecimal16
-struct Decimal128Spec;
-
-impl DecimalSpec for Decimal128Spec {
-    type Arrow = Decimal128Type;
-
-    fn into_variant(raw: i128, scale: i8) -> Result<Variant<'static, 'static>> 
{
-        let scale =
-            u8::try_from(scale).map_err(|e| 
ArrowError::InvalidArgumentError(e.to_string()))?;
-        let value = VariantDecimal16::try_new(raw, scale)
-            .map_err(|e| ArrowError::InvalidArgumentError(e.to_string()))?;
-        Ok(value.into())
-    }
-}
-
-/// Generic builder for decimal unshredding that caches scale
-struct DecimalUnshredRowBuilder<'a, S: DecimalSpec> {
+/// Generic builder for decimal unshredding
+struct DecimalUnshredRowBuilder<'a, A: DecimalType, V>
+where
+    V: VariantDecimalType<Native = A::Native>,
+{
     value: Option<&'a BinaryViewArray>,
-    typed_value: &'a PrimitiveArray<S::Arrow>,
+    typed_value: &'a PrimitiveArray<A>,
     scale: i8,
+    _phantom: PhantomData<V>,
 }
 
-impl<'a, S: DecimalSpec> DecimalUnshredRowBuilder<'a, S> {
-    fn new(
-        value: Option<&'a BinaryViewArray>,
-        typed_value: &'a PrimitiveArray<S::Arrow>,
-        scale: i8,
-    ) -> Self {
+impl<'a, A: DecimalType, V> DecimalUnshredRowBuilder<'a, A, V>
+where
+    V: VariantDecimalType<Native = A::Native>,
+{
+    fn new(value: Option<&'a BinaryViewArray>, typed_value: &'a dyn Array, 
scale: i8) -> Self {
         Self {
             value,
-            typed_value,
+            typed_value: typed_value.as_primitive(),
             scale,
+            _phantom: PhantomData,
         }
     }
 
@@ -588,7 +528,7 @@ impl<'a, S: DecimalSpec> DecimalUnshredRowBuilder<'a, S> {
         handle_unshredded_case!(self, builder, metadata, index, false);
 
         let raw = self.typed_value.value(index);
-        let variant = S::into_variant(raw, self.scale)?;
+        let variant = V::try_new_with_signed_scale(raw, self.scale)?;
         builder.append_value(variant);
         Ok(())
     }
diff --git a/parquet-variant-compute/src/variant_array.rs 
b/parquet-variant-compute/src/variant_array.rs
index 5686d102d3..522c5a7546 100644
--- a/parquet-variant-compute/src/variant_array.rs
+++ b/parquet-variant-compute/src/variant_array.rs
@@ -26,13 +26,11 @@ use arrow::datatypes::{
     TimestampMicrosecondType, TimestampNanosecondType,
 };
 use arrow_schema::extension::ExtensionType;
-use arrow_schema::{
-    ArrowError, DECIMAL32_MAX_PRECISION, DECIMAL64_MAX_PRECISION, 
DECIMAL128_MAX_PRECISION,
-    DataType, Field, FieldRef, Fields, TimeUnit,
-};
+use arrow_schema::{ArrowError, DataType, Field, FieldRef, Fields, TimeUnit};
 use chrono::DateTime;
-use parquet_variant::Uuid;
-use parquet_variant::Variant;
+use parquet_variant::{
+    Uuid, Variant, VariantDecimal4, VariantDecimal8, VariantDecimal16, 
VariantDecimalType as _,
+};
 
 use std::borrow::Cow;
 use std::sync::Arc;
@@ -937,18 +935,6 @@ fn cast_to_binary_view_arrays(array: &dyn Array) -> 
Result<ArrayRef, ArrowError>
     cast(array, new_type.as_ref())
 }
 
-/// Validates whether a given arrow decimal is a valid variant decimal
-///
-/// NOTE: By a strict reading of the "decimal table" in the [shredding spec], 
each decimal type
-/// should have a width-dependent lower bound on precision as well as an upper 
bound (i.e. Decimal16
-/// with precision 5 is invalid because Decimal4 "covers" it). But the variant 
shredding integration
-/// tests specifically expect such cases to succeed, so we only enforce the 
upper bound here.
-///
-/// [shredding spec]: 
https://github.com/apache/parquet-format/blob/master/VariantEncoding.md#encoding-types
-fn is_valid_variant_decimal(p: &u8, s: &i8, max_precision: u8) -> bool {
-    (1..=max_precision).contains(p) && (0..=*p as i8).contains(s)
-}
-
 /// Recursively visits a data type, ensuring that it only contains data types 
that can legally
 /// appear in a (possibly shredded) variant array. It also replaces Binary 
fields with BinaryView,
 /// since that's what comes back from the parquet reader and what the variant 
code expects to find.
@@ -984,16 +970,16 @@ fn canonicalize_and_verify_data_type(
         // NOTE: arrow-parquet reads widens 32- and 64-bit decimals to 
128-bit, but the variant spec
         // requires using the narrowest decimal type for a given precision. 
Fix those up first.
         Decimal64(p, s) | Decimal128(p, s)
-            if is_valid_variant_decimal(p, s, DECIMAL32_MAX_PRECISION) =>
+            if VariantDecimal4::is_valid_precision_and_scale(p, s) =>
         {
             Cow::Owned(Decimal32(*p, *s))
         }
-        Decimal128(p, s) if is_valid_variant_decimal(p, s, 
DECIMAL64_MAX_PRECISION) => {
+        Decimal128(p, s) if VariantDecimal8::is_valid_precision_and_scale(p, 
s) => {
             Cow::Owned(Decimal64(*p, *s))
         }
-        Decimal32(p, s) if is_valid_variant_decimal(p, s, 
DECIMAL32_MAX_PRECISION) => borrow!(),
-        Decimal64(p, s) if is_valid_variant_decimal(p, s, 
DECIMAL64_MAX_PRECISION) => borrow!(),
-        Decimal128(p, s) if is_valid_variant_decimal(p, s, 
DECIMAL128_MAX_PRECISION) => borrow!(),
+        Decimal32(p, s) if VariantDecimal4::is_valid_precision_and_scale(p, s) 
=> borrow!(),
+        Decimal64(p, s) if VariantDecimal8::is_valid_precision_and_scale(p, s) 
=> borrow!(),
+        Decimal128(p, s) if VariantDecimal16::is_valid_precision_and_scale(p, 
s) => borrow!(),
         Decimal32(..) | Decimal64(..) | Decimal128(..) | Decimal256(..) => 
fail!(),
 
         // Only micro and nano timestamps are allowed
diff --git a/parquet-variant/src/variant.rs b/parquet-variant/src/variant.rs
index aa3eb51ed3..819c20d554 100644
--- a/parquet-variant/src/variant.rs
+++ b/parquet-variant/src/variant.rs
@@ -15,7 +15,7 @@
 // specific language governing permissions and limitations
 // under the License.
 
-pub use self::decimal::{VariantDecimal4, VariantDecimal8, VariantDecimal16};
+pub use self::decimal::{VariantDecimal4, VariantDecimal8, VariantDecimal16, 
VariantDecimalType};
 pub use self::list::VariantList;
 pub use self::metadata::{EMPTY_VARIANT_METADATA, EMPTY_VARIANT_METADATA_BYTES, 
VariantMetadata};
 pub use self::object::VariantObject;
diff --git a/parquet-variant/src/variant/decimal.rs 
b/parquet-variant/src/variant/decimal.rs
index b0b7d36ed1..c7849a381a 100644
--- a/parquet-variant/src/variant/decimal.rs
+++ b/parquet-variant/src/variant/decimal.rs
@@ -17,52 +17,188 @@
 use arrow_schema::ArrowError;
 use std::fmt;
 
-// All decimal types use the same try_new implementation
-macro_rules! decimal_try_new {
-    ($integer:ident, $scale:ident) => {{
-        // Validate that scale doesn't exceed precision
-        if $scale > Self::MAX_PRECISION {
-            return Err(ArrowError::InvalidArgumentError(format!(
-                "Scale {} is larger than max precision {}",
-                $scale,
-                Self::MAX_PRECISION,
-            )));
-        }
+/// Trait for variant decimal types, enabling generic code across Decimal4/8/16
+///
+/// This trait provides a common interface for the three variant decimal types,
+/// allowing generic functions and data structures to work with any decimal 
width.
+/// It is modeled after Arrow's `DecimalType` trait but adapted for variant 
semantics.
+///
+/// # Example
+///
+/// ```
+/// # use parquet_variant::{VariantDecimal4, VariantDecimal8, 
VariantDecimalType};
+/// #
+/// fn extract_scale<D: VariantDecimalType>(decimal: D) -> u8 {
+///     decimal.scale()
+/// }
+///
+/// let dec4 = VariantDecimal4::try_new(12345, 2).unwrap();
+/// let dec8 = VariantDecimal8::try_new(67890, 3).unwrap();
+///
+/// assert_eq!(extract_scale(dec4), 2);
+/// assert_eq!(extract_scale(dec8), 3);
+/// ```
+pub trait VariantDecimalType: Into<super::Variant<'static, 'static>> {
+    /// The underlying signed integer type (i32, i64, or i128)
+    type Native;
 
-        // Validate that the integer value fits within the precision
-        if $integer.unsigned_abs() > Self::MAX_UNSCALED_VALUE {
-            return Err(ArrowError::InvalidArgumentError(format!(
-                "{} is wider than max precision {}",
-                $integer,
-                Self::MAX_PRECISION
-            )));
-        }
+    /// Maximum number of significant digits this decimal type can represent 
(9, 18, or 38)
+    const MAX_PRECISION: u8;
+    /// The largest positive unscaled value that fits in 
[`Self::MAX_PRECISION`] digits.
+    const MAX_UNSCALED_VALUE: Self::Native;
 
-        Ok(Self { $integer, $scale })
-    }};
+    /// True if the given precision and scale are valid for this variant 
decimal type.
+    ///
+    /// NOTE: By a strict reading of the "decimal table" in the [variant 
spec], one might conclude that
+    /// each decimal type has both lower and upper bounds on precision (i.e. 
Decimal16 with precision 5
+    /// is invalid because Decimal4 "covers" it). But the variant shredding 
integration tests
+    /// specifically expect such cases to succeed, so we only enforce the 
upper bound here.
+    ///
+    /// [variant spec]: https://github.com/apache/parquet-format/blob/master/VariantEncoding.md#encoding-types
+    ///
+    /// # Example
+    /// ```
+    /// # use parquet_variant::{VariantDecimal4, VariantDecimalType};
+    /// #
+    /// assert!(VariantDecimal4::is_valid_precision_and_scale(&5, &2));
+    /// assert!(!VariantDecimal4::is_valid_precision_and_scale(&10, &2)); // 
too wide
+    /// assert!(!VariantDecimal4::is_valid_precision_and_scale(&5, &-1)); // 
negative scale
+    /// assert!(!VariantDecimal4::is_valid_precision_and_scale(&5, &7)); // 
scale too big
+    /// ```
+    fn is_valid_precision_and_scale(precision: &u8, scale: &i8) -> bool {
+        (1..=Self::MAX_PRECISION).contains(precision) && (0..=*precision as 
i8).contains(scale)
+    }
+
+    /// Creates a new decimal value from the given unscaled integer and scale, 
failing if the
+    /// integer's width, or the requested scale, exceeds `MAX_PRECISION`.
+    ///
+    /// NOTE: `scale` is unsigned here. Arrow decimal types also allow negative scales; use
+    /// [`Self::try_new_with_signed_scale`] to rescale such values before construction.
+    ///
+    /// # Example
+    ///
+    /// ```
+    /// # use parquet_variant::{VariantDecimal4, VariantDecimalType};
+    /// #
+    /// // Valid: 123.45 (5 digits, scale 2)
+    /// let d = VariantDecimal4::try_new(12345, 2).unwrap();
+    /// assert_eq!(d.integer(), 12345);
+    /// assert_eq!(d.scale(), 2);
+    ///
+    /// VariantDecimal4::try_new(123, 10).expect_err("scale exceeds 
MAX_PRECISION");
+    /// VariantDecimal4::try_new(1234567890, 10).expect_err("value's width 
exceeds MAX_PRECISION");
+    /// ```
+    fn try_new(integer: Self::Native, scale: u8) -> Result<Self, ArrowError>;
+
+    /// Attempts to convert an unscaled arrow decimal value to the indicated 
variant decimal type.
+    ///
+    /// Unlike [`Self::try_new`], this function accepts a signed scale, and 
attempts to rescale
+    /// negative-scale values to their equivalent (larger) scale-0 values. For 
example, a decimal
+    /// value of 123 with scale -2 becomes 12300 with scale 0.
+    ///
+    /// Fails if rescaling fails, or for any of the reasons [`Self::try_new`] 
could fail.
+    fn try_new_with_signed_scale(integer: Self::Native, scale: i8) -> 
Result<Self, ArrowError>;
+
+    /// Returns the unscaled integer value
+    fn integer(&self) -> Self::Native;
+
+    /// Returns the scale (number of digits after the decimal point)
+    fn scale(&self) -> u8;
 }
 
-// All decimal values format the same way, using integer arithmetic to avoid 
floating point precision loss
-macro_rules! format_decimal {
-    ($f:expr, $integer:expr, $scale:expr, $int_type:ty) => {{
-        let integer = if $scale == 0 {
-            $integer
-        } else {
-            let divisor = <$int_type>::pow(10, $scale as u32);
-            let remainder = $integer % divisor;
-            if remainder != 0 {
-                // Track the sign explicitly, in case the quotient is zero
-                let sign = if $integer < 0 { "-" } else { "" };
-                // Format an unsigned remainder with leading zeros and strip 
(unnecessary) trailing zeros.
-                let remainder = format!("{:0width$}", remainder.abs(), width = 
$scale as usize);
-                let remainder = remainder.trim_end_matches('0');
-                let quotient = $integer / divisor;
-                return write!($f, "{}{}.{}", sign, quotient.abs(), remainder);
+/// Implements the complete variant decimal type: methods, Display, and 
VariantDecimalType trait
+macro_rules! impl_variant_decimal {
+    ($struct_name:ident, $native:ty) => {
+        impl $struct_name {
+            /// Attempts to create a new instance of this decimal type, 
failing if the value is too
+            /// wide or the scale is too large.
+            pub fn try_new(integer: $native, scale: u8) -> Result<Self, 
ArrowError> {
+                let max_precision = Self::MAX_PRECISION;
+                if scale > max_precision {
+                    return Err(ArrowError::InvalidArgumentError(format!(
+                        "Scale {scale} is larger than max precision 
{max_precision}",
+                    )));
+                }
+                if 
!(-Self::MAX_UNSCALED_VALUE..=Self::MAX_UNSCALED_VALUE).contains(&integer) {
+                    return Err(ArrowError::InvalidArgumentError(format!(
+                        "{integer} is wider than max precision 
{max_precision}",
+                    )));
+                }
+
+                Ok(Self { integer, scale })
+            }
+
+            /// Returns the unscaled integer value of the decimal.
+            ///
+            /// For example, if the decimal is `123.45`, this will return 
`12345`.
+            pub fn integer(&self) -> $native {
+                self.integer
+            }
+
+            /// Returns the scale of the decimal (how many digits after the 
decimal point).
+            ///
+            /// For example, if the decimal is `123.45`, this will return `2`.
+            pub fn scale(&self) -> u8 {
+                self.scale
+            }
+        }
+
+        impl VariantDecimalType for $struct_name {
+            type Native = $native;
+            const MAX_PRECISION: u8 = Self::MAX_PRECISION;
+            const MAX_UNSCALED_VALUE: $native = <$native>::pow(10, 
Self::MAX_PRECISION as u32) - 1;
+
+            fn try_new(integer: $native, scale: u8) -> Result<Self, 
ArrowError> {
+                Self::try_new(integer, scale)
+            }
+
+            fn try_new_with_signed_scale(integer: $native, scale: i8) -> 
Result<Self, ArrowError> {
+                let (integer, scale) = if scale < 0 {
+                    // Use `unsigned_abs` rather than `-scale`: negating `i8::MIN` overflows
+                    // (panicking in debug builds), whereas `unsigned_abs` yields 128, for
+                    // which `checked_pow` cleanly reports overflow via `None` below.
+                    let multiplier = <$native>::checked_pow(10, u32::from(scale.unsigned_abs()));
+                    let Some(rescaled) = multiplier.and_then(|m| 
integer.checked_mul(m)) else {
+                        return Err(ArrowError::InvalidArgumentError(format!(
+                            "Overflow when rescaling {integer} with scale 
{scale}"
+                        )));
+                    };
+                    (rescaled, 0u8)
+                } else {
+                    (integer, scale as u8)
+                };
+                Self::try_new(integer, scale)
             }
-            $integer / divisor
-        };
-        write!($f, "{}", integer)
-    }};
+
+            fn integer(&self) -> $native {
+                self.integer()
+            }
+
+            fn scale(&self) -> u8 {
+                self.scale()
+            }
+        }
+
+        impl fmt::Display for $struct_name {
+            fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
+                let integer = if self.scale == 0 {
+                    self.integer
+                } else {
+                    let divisor = <$native>::pow(10, self.scale as u32);
+                    let remainder = self.integer % divisor;
+                    if remainder != 0 {
+                        // Track the sign explicitly, in case the quotient is 
zero
+                        let sign = if self.integer < 0 { "-" } else { "" };
+                        // Format an unsigned remainder with leading zeros and 
strip trailing zeros
+                        let remainder =
+                            format!("{:0width$}", remainder.abs(), width = 
self.scale as usize);
+                        let remainder = remainder.trim_end_matches('0');
+                        let quotient = (self.integer / divisor).abs();
+                        return write!(f, "{sign}{quotient}.{remainder}");
+                    }
+                    self.integer / divisor
+                };
+                write!(f, "{integer}")
+            }
+        }
+    };
 }
 
 /// Represents a 4-byte decimal value in the Variant format.
@@ -86,33 +222,11 @@ pub struct VariantDecimal4 {
 }
 
 impl VariantDecimal4 {
-    pub(crate) const MAX_PRECISION: u8 = 9;
-    pub(crate) const MAX_UNSCALED_VALUE: u32 = u32::pow(10, 
Self::MAX_PRECISION as u32) - 1;
-
-    pub fn try_new(integer: i32, scale: u8) -> Result<Self, ArrowError> {
-        decimal_try_new!(integer, scale)
-    }
-
-    /// Returns the underlying value of the decimal.
-    ///
-    /// For example, if the decimal is `123.4567`, this will return `1234567`.
-    pub fn integer(&self) -> i32 {
-        self.integer
-    }
-
-    /// Returns the scale of the decimal (how many digits after the decimal 
point).
-    ///
-    /// For example, if the decimal is `123.4567`, this will return `4`.
-    pub fn scale(&self) -> u8 {
-        self.scale
-    }
+    /// Maximum number of significant digits (9 for 4-byte decimals)
+    pub const MAX_PRECISION: u8 = arrow_schema::DECIMAL32_MAX_PRECISION;
 }
 
-impl fmt::Display for VariantDecimal4 {
-    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
-        format_decimal!(f, self.integer, self.scale, i32)
-    }
-}
+impl_variant_decimal!(VariantDecimal4, i32);
 
 /// Represents an 8-byte decimal value in the Variant format.
 ///
@@ -136,33 +250,11 @@ pub struct VariantDecimal8 {
 }
 
 impl VariantDecimal8 {
-    pub(crate) const MAX_PRECISION: u8 = 18;
-    pub(crate) const MAX_UNSCALED_VALUE: u64 = u64::pow(10, 
Self::MAX_PRECISION as u32) - 1;
-
-    pub fn try_new(integer: i64, scale: u8) -> Result<Self, ArrowError> {
-        decimal_try_new!(integer, scale)
-    }
-
-    /// Returns the underlying value of the decimal.
-    ///
-    /// For example, if the decimal is `123456.78`, this will return 
`12345678`.
-    pub fn integer(&self) -> i64 {
-        self.integer
-    }
-
-    /// Returns the scale of the decimal (how many digits after the decimal 
point).
-    ///
-    /// For example, if the decimal is `123456.78`, this will return `2`.
-    pub fn scale(&self) -> u8 {
-        self.scale
-    }
+    /// Maximum number of significant digits (18 for 8-byte decimals)
+    pub const MAX_PRECISION: u8 = arrow_schema::DECIMAL64_MAX_PRECISION;
 }
 
-impl fmt::Display for VariantDecimal8 {
-    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
-        format_decimal!(f, self.integer, self.scale, i64)
-    }
-}
+impl_variant_decimal!(VariantDecimal8, i64);
 
 /// Represents an 16-byte decimal value in the Variant format.
 ///
@@ -186,33 +278,11 @@ pub struct VariantDecimal16 {
 }
 
 impl VariantDecimal16 {
-    const MAX_PRECISION: u8 = 38;
-    const MAX_UNSCALED_VALUE: u128 = u128::pow(10, Self::MAX_PRECISION as u32) 
- 1;
-
-    pub fn try_new(integer: i128, scale: u8) -> Result<Self, ArrowError> {
-        decimal_try_new!(integer, scale)
-    }
-
-    /// Returns the underlying value of the decimal.
-    ///
-    /// For example, if the decimal is `12345678901234567.890`, this will 
return `12345678901234567890`.
-    pub fn integer(&self) -> i128 {
-        self.integer
-    }
-
-    /// Returns the scale of the decimal (how many digits after the decimal 
point).
-    ///
-    /// For example, if the decimal is `12345678901234567.890`, this will 
return `3`.
-    pub fn scale(&self) -> u8 {
-        self.scale
-    }
+    /// Maximum number of significant digits (38 for 16-byte decimals)
+    pub const MAX_PRECISION: u8 = arrow_schema::DECIMAL128_MAX_PRECISION;
 }
 
-impl fmt::Display for VariantDecimal16 {
-    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
-        format_decimal!(f, self.integer, self.scale, i128)
-    }
-}
+impl_variant_decimal!(VariantDecimal16, i128);
 
 // Infallible conversion from a narrower decimal type to a wider one
 macro_rules! impl_from_decimal_for_decimal {


Reply via email to