This is an automated email from the ASF dual-hosted git repository.

alamb pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/arrow-rs.git


The following commit(s) were added to refs/heads/main by this push:
     new 8e5d826cc9 [Variant] Decimal unshredding support (#8540)
8e5d826cc9 is described below

commit 8e5d826cc9d8fe7cd19c4f657a6c0c4f757f7b62
Author: Ryan Johnson <[email protected]>
AuthorDate: Fri Oct 3 10:22:27 2025 -0600

    [Variant] Decimal unshredding support (#8540)
    
    # Which issue does this PR close?
    
    - Closes https://github.com/apache/arrow-rs/issues/8332
    
    # Rationale for this change
    
    Missing feature: unshredding did not yet support decimal `typed_value` columns.
    
    # What changes are included in this PR?
    
    Add decimal unshredding support, which _should_ have been
    straightforward except:
    1. The variant decimal types are not generic and do not implement any
    common trait that lets us generalize the logic easily. I added a custom
    trait in the unshredding module as a workaround, but we should probably
    look at something similar to arrow's `DecimalType` trait for
    `VariantDecimalXX` classes to implement.
    2. The parquet reader seems to have a bug (feature?) that forces 32- and
    64-bit decimal columns to Decimal128 unless the reader specifically
    requests a narrower type. This causes the variant decimal integration
    tests to fail because they receive `Variant::Decimal16` values when they
    expected `Variant::Decimal4` or `Variant::Decimal8` (the actual values
    are correct). Rather than directly tackle the bug in arrow-parquet
    itself (which has a large blast radius), I updated the `VariantArray`
    constructor to cast such columns back to the correct type as needed.
    
    # Are these changes tested?
    
    Yes. The variant decimal integration tests now pass where they used to
    fail.
    
    # Are there any user-facing changes?
    
    No.
---
 parquet-variant-compute/src/unshred_variant.rs | 127 ++++++++++++++++++++++++-
 parquet-variant-compute/src/variant_array.rs   |  27 +++++-
 parquet/tests/variant_integration.rs           |  31 ++----
 3 files changed, 152 insertions(+), 33 deletions(-)

diff --git a/parquet-variant-compute/src/unshred_variant.rs 
b/parquet-variant-compute/src/unshred_variant.rs
index 264ef458d9..64eaa46ed0 100644
--- a/parquet-variant-compute/src/unshred_variant.rs
+++ b/parquet-variant-compute/src/unshred_variant.rs
@@ -25,15 +25,18 @@ use arrow::array::{
 };
 use arrow::buffer::NullBuffer;
 use arrow::datatypes::{
-    ArrowPrimitiveType, DataType, Date32Type, Float32Type, Float64Type, 
Int8Type, Int16Type,
-    Int32Type, Int64Type, Time64MicrosecondType, TimeUnit, 
TimestampMicrosecondType,
-    TimestampNanosecondType,
+    ArrowPrimitiveType, DataType, Date32Type, Decimal32Type, Decimal64Type, 
Decimal128Type,
+    DecimalType, Float32Type, Float64Type, Int8Type, Int16Type, Int32Type, 
Int64Type,
+    Time64MicrosecondType, TimeUnit, TimestampMicrosecondType, 
TimestampNanosecondType,
 };
 use arrow::error::{ArrowError, Result};
 use arrow::temporal_conversions::time64us_to_time;
 use chrono::{DateTime, Utc};
 use indexmap::IndexMap;
-use parquet_variant::{ObjectFieldBuilder, Variant, VariantBuilderExt, 
VariantMetadata};
+use parquet_variant::{
+    ObjectFieldBuilder, Variant, VariantBuilderExt, VariantDecimal4, 
VariantDecimal8,
+    VariantDecimal16, VariantMetadata,
+};
 use uuid::Uuid;
 
 /// Removes all (nested) typed_value columns from a VariantArray by converting 
them back to binary
@@ -92,6 +95,9 @@ enum UnshredVariantRowBuilder<'a> {
     PrimitiveInt64(UnshredPrimitiveRowBuilder<'a, PrimitiveArray<Int64Type>>),
     PrimitiveFloat32(UnshredPrimitiveRowBuilder<'a, 
PrimitiveArray<Float32Type>>),
     PrimitiveFloat64(UnshredPrimitiveRowBuilder<'a, 
PrimitiveArray<Float64Type>>),
+    Decimal32(DecimalUnshredRowBuilder<'a, Decimal32Spec>),
+    Decimal64(DecimalUnshredRowBuilder<'a, Decimal64Spec>),
+    Decimal128(DecimalUnshredRowBuilder<'a, Decimal128Spec>),
     PrimitiveDate32(UnshredPrimitiveRowBuilder<'a, 
PrimitiveArray<Date32Type>>),
     PrimitiveTime64(UnshredPrimitiveRowBuilder<'a, 
PrimitiveArray<Time64MicrosecondType>>),
     TimestampMicrosecond(TimestampUnshredRowBuilder<'a, 
TimestampMicrosecondType>),
@@ -130,6 +136,9 @@ impl<'a> UnshredVariantRowBuilder<'a> {
             Self::PrimitiveInt64(b) => b.append_row(builder, metadata, index),
             Self::PrimitiveFloat32(b) => b.append_row(builder, metadata, 
index),
             Self::PrimitiveFloat64(b) => b.append_row(builder, metadata, 
index),
+            Self::Decimal32(b) => b.append_row(builder, metadata, index),
+            Self::Decimal64(b) => b.append_row(builder, metadata, index),
+            Self::Decimal128(b) => b.append_row(builder, metadata, index),
             Self::PrimitiveDate32(b) => b.append_row(builder, metadata, index),
             Self::PrimitiveTime64(b) => b.append_row(builder, metadata, index),
             Self::TimestampMicrosecond(b) => b.append_row(builder, metadata, 
index),
@@ -176,6 +185,26 @@ impl<'a> UnshredVariantRowBuilder<'a> {
             DataType::Int64 => primitive_builder!(PrimitiveInt64, 
as_primitive),
             DataType::Float32 => primitive_builder!(PrimitiveFloat32, 
as_primitive),
             DataType::Float64 => primitive_builder!(PrimitiveFloat64, 
as_primitive),
+            DataType::Decimal32(_, scale) => 
Self::Decimal32(DecimalUnshredRowBuilder::new(
+                value,
+                typed_value.as_primitive(),
+                *scale,
+            )),
+            DataType::Decimal64(_, scale) => 
Self::Decimal64(DecimalUnshredRowBuilder::new(
+                value,
+                typed_value.as_primitive(),
+                *scale,
+            )),
+            DataType::Decimal128(_, scale) => 
Self::Decimal128(DecimalUnshredRowBuilder::new(
+                value,
+                typed_value.as_primitive(),
+                *scale,
+            )),
+            DataType::Decimal256(_, _) => {
+                return Err(ArrowError::InvalidArgumentError(
+                    "Decimal256 is not a valid variant shredding 
type".to_string(),
+                ));
+            }
             DataType::Date32 => primitive_builder!(PrimitiveDate32, 
as_primitive),
             DataType::Time64(TimeUnit::Microsecond) => {
                 primitive_builder!(PrimitiveTime64, as_primitive)
@@ -475,6 +504,96 @@ impl<'a, T: TimestampType> TimestampUnshredRowBuilder<'a, 
T> {
     }
 }
 
+/// Trait to unify decimal unshredding across Decimal32/64/128 types
+trait DecimalSpec {
+    type Arrow: ArrowPrimitiveType + DecimalType;
+
+    fn into_variant(
+        raw: <Self::Arrow as ArrowPrimitiveType>::Native,
+        scale: i8,
+    ) -> Result<Variant<'static, 'static>>;
+}
+
+/// Spec for Decimal32 -> VariantDecimal4
+struct Decimal32Spec;
+
+impl DecimalSpec for Decimal32Spec {
+    type Arrow = Decimal32Type;
+
+    fn into_variant(raw: i32, scale: i8) -> Result<Variant<'static, 'static>> {
+        let scale =
+            u8::try_from(scale).map_err(|e| 
ArrowError::InvalidArgumentError(e.to_string()))?;
+        let value = VariantDecimal4::try_new(raw, scale)
+            .map_err(|e| ArrowError::InvalidArgumentError(e.to_string()))?;
+        Ok(value.into())
+    }
+}
+
+/// Spec for Decimal64 -> VariantDecimal8
+struct Decimal64Spec;
+
+impl DecimalSpec for Decimal64Spec {
+    type Arrow = Decimal64Type;
+
+    fn into_variant(raw: i64, scale: i8) -> Result<Variant<'static, 'static>> {
+        let scale =
+            u8::try_from(scale).map_err(|e| 
ArrowError::InvalidArgumentError(e.to_string()))?;
+        let value = VariantDecimal8::try_new(raw, scale)
+            .map_err(|e| ArrowError::InvalidArgumentError(e.to_string()))?;
+        Ok(value.into())
+    }
+}
+
+/// Spec for Decimal128 -> VariantDecimal16
+struct Decimal128Spec;
+
+impl DecimalSpec for Decimal128Spec {
+    type Arrow = Decimal128Type;
+
+    fn into_variant(raw: i128, scale: i8) -> Result<Variant<'static, 'static>> 
{
+        let scale =
+            u8::try_from(scale).map_err(|e| 
ArrowError::InvalidArgumentError(e.to_string()))?;
+        let value = VariantDecimal16::try_new(raw, scale)
+            .map_err(|e| ArrowError::InvalidArgumentError(e.to_string()))?;
+        Ok(value.into())
+    }
+}
+
+/// Generic builder for decimal unshredding that caches scale
+struct DecimalUnshredRowBuilder<'a, S: DecimalSpec> {
+    value: Option<&'a BinaryViewArray>,
+    typed_value: &'a PrimitiveArray<S::Arrow>,
+    scale: i8,
+}
+
+impl<'a, S: DecimalSpec> DecimalUnshredRowBuilder<'a, S> {
+    fn new(
+        value: Option<&'a BinaryViewArray>,
+        typed_value: &'a PrimitiveArray<S::Arrow>,
+        scale: i8,
+    ) -> Self {
+        Self {
+            value,
+            typed_value,
+            scale,
+        }
+    }
+
+    fn append_row(
+        &mut self,
+        builder: &mut impl VariantBuilderExt,
+        metadata: &VariantMetadata,
+        index: usize,
+    ) -> Result<()> {
+        handle_unshredded_case!(self, builder, metadata, index, false);
+
+        let raw = self.typed_value.value(index);
+        let variant = S::into_variant(raw, self.scale)?;
+        builder.append_value(variant);
+        Ok(())
+    }
+}
+
 /// Builder for unshredding struct/object types with nested fields
 struct StructUnshredVariantBuilder<'a> {
     value: Option<&'a arrow::array::BinaryViewArray>,
diff --git a/parquet-variant-compute/src/variant_array.rs 
b/parquet-variant-compute/src/variant_array.rs
index 51ed10b3cf..5686d102d3 100644
--- a/parquet-variant-compute/src/variant_array.rs
+++ b/parquet-variant-compute/src/variant_array.rs
@@ -26,7 +26,10 @@ use arrow::datatypes::{
     TimestampMicrosecondType, TimestampNanosecondType,
 };
 use arrow_schema::extension::ExtensionType;
-use arrow_schema::{ArrowError, DataType, Field, FieldRef, Fields, TimeUnit};
+use arrow_schema::{
+    ArrowError, DECIMAL32_MAX_PRECISION, DECIMAL64_MAX_PRECISION, 
DECIMAL128_MAX_PRECISION,
+    DataType, Field, FieldRef, Fields, TimeUnit,
+};
 use chrono::DateTime;
 use parquet_variant::Uuid;
 use parquet_variant::Variant;
@@ -926,6 +929,11 @@ fn typed_value_to_variant<'a>(
 /// So cast them to get the right type.
 fn cast_to_binary_view_arrays(array: &dyn Array) -> Result<ArrayRef, 
ArrowError> {
     let new_type = canonicalize_and_verify_data_type(array.data_type())?;
+    if let Cow::Borrowed(_) = new_type {
+        if let Some(array) = array.as_struct_opt() {
+            return Ok(Arc::new(array.clone())); // bypass the unnecessary cast
+        }
+    }
     cast(array, new_type.as_ref())
 }
 
@@ -972,9 +980,20 @@ fn canonicalize_and_verify_data_type(
         UInt8 | UInt16 | UInt32 | UInt64 | Float16 => fail!(),
 
         // Most decimal types are allowed, with restrictions on precision and 
scale
-        Decimal32(p, s) if is_valid_variant_decimal(p, s, 9) => borrow!(),
-        Decimal64(p, s) if is_valid_variant_decimal(p, s, 18) => borrow!(),
-        Decimal128(p, s) if is_valid_variant_decimal(p, s, 38) => borrow!(),
+        //
+        // NOTE: the arrow-parquet reader widens 32- and 64-bit decimals to 
128-bit, but the variant spec
+        // requires using the narrowest decimal type for a given precision. 
Fix those up first.
+        Decimal64(p, s) | Decimal128(p, s)
+            if is_valid_variant_decimal(p, s, DECIMAL32_MAX_PRECISION) =>
+        {
+            Cow::Owned(Decimal32(*p, *s))
+        }
+        Decimal128(p, s) if is_valid_variant_decimal(p, s, 
DECIMAL64_MAX_PRECISION) => {
+            Cow::Owned(Decimal64(*p, *s))
+        }
+        Decimal32(p, s) if is_valid_variant_decimal(p, s, 
DECIMAL32_MAX_PRECISION) => borrow!(),
+        Decimal64(p, s) if is_valid_variant_decimal(p, s, 
DECIMAL64_MAX_PRECISION) => borrow!(),
+        Decimal128(p, s) if is_valid_variant_decimal(p, s, 
DECIMAL128_MAX_PRECISION) => borrow!(),
         Decimal32(..) | Decimal64(..) | Decimal128(..) | Decimal256(..) => 
fail!(),
 
         // Only micro and nano timestamps are allowed
diff --git a/parquet/tests/variant_integration.rs 
b/parquet/tests/variant_integration.rs
index 98fa04555d..48f23c46b8 100644
--- a/parquet/tests/variant_integration.rs
+++ b/parquet/tests/variant_integration.rs
@@ -86,31 +86,12 @@ variant_test_case!(20);
 variant_test_case!(21);
 variant_test_case!(22);
 variant_test_case!(23);
-// https://github.com/apache/arrow-rs/issues/8332
-variant_test_case!(
-    24,
-    "Unshredding not yet supported for type: Decimal128(9, 4)"
-);
-variant_test_case!(
-    25,
-    "Unshredding not yet supported for type: Decimal128(9, 4)"
-);
-variant_test_case!(
-    26,
-    "Unshredding not yet supported for type: Decimal128(18, 9)"
-);
-variant_test_case!(
-    27,
-    "Unshredding not yet supported for type: Decimal128(18, 9)"
-);
-variant_test_case!(
-    28,
-    "Unshredding not yet supported for type: Decimal128(38, 9)"
-);
-variant_test_case!(
-    29,
-    "Unshredding not yet supported for type: Decimal128(38, 9)"
-);
+variant_test_case!(24);
+variant_test_case!(25);
+variant_test_case!(26);
+variant_test_case!(27);
+variant_test_case!(28);
+variant_test_case!(29);
 variant_test_case!(30);
 variant_test_case!(31);
 variant_test_case!(32);

Reply via email to