liamzwbao commented on code in PR #8552:
URL: https://github.com/apache/arrow-rs/pull/8552#discussion_r2412384093


##########
parquet-variant-compute/src/variant_to_arrow.rs:
##########
@@ -145,62 +165,75 @@ pub(crate) fn make_primitive_variant_to_arrow_row_builder<'a>(
 ) -> Result<PrimitiveVariantToArrowRowBuilder<'a>> {
     use PrimitiveVariantToArrowRowBuilder::*;
 
-    let builder = match data_type {
-        DataType::Int8 => Int8(VariantToPrimitiveArrowRowBuilder::new(
-            cast_options,
-            capacity,
-        )),
-        DataType::Int16 => Int16(VariantToPrimitiveArrowRowBuilder::new(
-            cast_options,
-            capacity,
-        )),
-        DataType::Int32 => Int32(VariantToPrimitiveArrowRowBuilder::new(
-            cast_options,
-            capacity,
-        )),
-        DataType::Int64 => Int64(VariantToPrimitiveArrowRowBuilder::new(
-            cast_options,
-            capacity,
-        )),
-        DataType::UInt8 => UInt8(VariantToPrimitiveArrowRowBuilder::new(
-            cast_options,
-            capacity,
-        )),
-        DataType::UInt16 => UInt16(VariantToPrimitiveArrowRowBuilder::new(
-            cast_options,
-            capacity,
-        )),
-        DataType::UInt32 => UInt32(VariantToPrimitiveArrowRowBuilder::new(
-            cast_options,
-            capacity,
-        )),
-        DataType::UInt64 => UInt64(VariantToPrimitiveArrowRowBuilder::new(
-            cast_options,
-            capacity,
-        )),
-        DataType::Float16 => Float16(VariantToPrimitiveArrowRowBuilder::new(
-            cast_options,
-            capacity,
-        )),
-        DataType::Float32 => Float32(VariantToPrimitiveArrowRowBuilder::new(
-            cast_options,
-            capacity,
-        )),
-        DataType::Float64 => Float64(VariantToPrimitiveArrowRowBuilder::new(
-            cast_options,
-            capacity,
-        )),
-        _ if data_type.is_primitive() => {
-            return Err(ArrowError::NotYetImplemented(format!(
-                "Primitive data_type {data_type:?} not yet implemented"
-            )));
-        }
-        _ => {
-            return Err(ArrowError::InvalidArgumentError(format!(
-                "Not a primitive type: {data_type:?}"
-            )));
-        }
-    };
+    let builder =
+        match data_type {
+            DataType::Int8 => Int8(VariantToPrimitiveArrowRowBuilder::new(
+                cast_options,
+                capacity,
+            )),
+            DataType::Int16 => Int16(VariantToPrimitiveArrowRowBuilder::new(
+                cast_options,
+                capacity,
+            )),
+            DataType::Int32 => Int32(VariantToPrimitiveArrowRowBuilder::new(
+                cast_options,
+                capacity,
+            )),
+            DataType::Int64 => Int64(VariantToPrimitiveArrowRowBuilder::new(
+                cast_options,
+                capacity,
+            )),
+            DataType::UInt8 => UInt8(VariantToPrimitiveArrowRowBuilder::new(
+                cast_options,
+                capacity,
+            )),
+            DataType::UInt16 => UInt16(VariantToPrimitiveArrowRowBuilder::new(
+                cast_options,
+                capacity,
+            )),
+            DataType::UInt32 => UInt32(VariantToPrimitiveArrowRowBuilder::new(
+                cast_options,
+                capacity,
+            )),
+            DataType::UInt64 => UInt64(VariantToPrimitiveArrowRowBuilder::new(
+                cast_options,
+                capacity,
+            )),
+            DataType::Float16 => Float16(VariantToPrimitiveArrowRowBuilder::new(
+                cast_options,
+                capacity,
+            )),
+            DataType::Float32 => Float32(VariantToPrimitiveArrowRowBuilder::new(
+                cast_options,
+                capacity,
+            )),
+            DataType::Float64 => Float64(VariantToPrimitiveArrowRowBuilder::new(
+                cast_options,
+                capacity,
+            )),
+            DataType::Decimal32(precision, scale) => Decimal32(
+                VariantToDecimalArrowRowBuilder::new(cast_options, capacity, *precision, *scale)?,

Review Comment:
   I reviewed the implementation of `cast_decimal_to_decimal<I, O>` in
   `arrow-cast`, and it seems to already handle our cases quite well.
   Specifically:
   
   1. It checks `is_infallible_cast`, which covers case 3.
   2. For scale-up (`s1 <= s2`), it first converts `I::Native` to `O::Native`
   and then rescales. For scale-down (`s1 > s2`), it divides and rounds in
   `I::Native` before converting the result to `O::Native`. This approach
   gracefully handles native-type overflow. The subsequent
   `DecimalType::is_valid_decimal_precision` call validates precision,
   similar to our current `MAX_DECIMAL32_FOR_EACH_PRECISION.get(n2 + s1)`
   check, which effectively covers cases 1 and 2, where `n2 < n1` or
   `n2 == n1`.
   3. That said, case 1 (`n2 < n1`) might present an optimization opportunity,
   since we could skip rescaling. Functionally, though, the results should be
   the same. This could be explored in a follow-up PR.
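   
   To make the scale-up and scale-down paths in point 2 concrete, here is a
   minimal sketch on plain `i128` values with no Arrow types involved. The
   name `rescale_unscaled` is hypothetical and is not the `arrow-cast` API.
   It assumes rounding half away from zero on scale-down, and it models the
   precision check as a simple bound of `10^precision`:
   
   ```rust
   /// Hypothetical helper (not part of arrow-cast): rescales an unscaled
   /// decimal integer from `input_scale` to `output_scale`, then checks that
   /// the result fits in `output_precision` decimal digits.
   fn rescale_unscaled(
       value: i128,
       input_scale: i8,
       output_precision: u8,
       output_scale: i8,
   ) -> Option<i128> {
       let delta = i32::from(output_scale) - i32::from(input_scale);
       let rescaled = if delta >= 0 {
           // Scale-up: multiply by 10^delta, returning None on overflow.
           let mul = 10_i128.checked_pow(delta as u32)?;
           value.checked_mul(mul)?
       } else {
           // Scale-down: divide by 10^(-delta), rounding half away from zero.
           // Conservatively returns None if the divisor itself overflows i128.
           let div = 10_i128.checked_pow((-delta) as u32)?;
           let half = div / 2;
           let (q, r) = (value / div, value % div);
           if r >= half {
               q + 1
           } else if r <= -half {
               q - 1
           } else {
               q
           }
       };
       // Precision check: |rescaled| must be strictly below 10^output_precision.
       let limit = 10_i128.checked_pow(u32::from(output_precision))?;
       (rescaled > -limit && rescaled < limit).then_some(rescaled)
   }
   ```
   
   For example, `rescale_unscaled(12345, 2, 9, 4)` scales 123.45 up to an
   unscaled 1234500, while `rescale_unscaled(12345, 2, 3, 2)` returns `None`
   because five digits do not fit in precision 3.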
   
   Given this overlap, instead of duplicating logic, I plan to refactor the 
decimal cast function by extracting the shared core logic into a helper like 
the one below and exposing it. This does mean we need a dependency on 
`arrow-cast`, though:
   ```rs
   fn rescale_decimal<I, O>(
       integer: I::Native,
       input_precision: u8,
       input_scale: i8,
       output_precision: u8,
       output_scale: i8,
   ) -> Option<O::Native>
   where
       I: DecimalType,
       O: DecimalType,
       I::Native: DecimalCast,
       O::Native: DecimalCast,
   ```
   
   Then, in our case, we can simply wire the type conversions through this 
helper:
   
   ```rs
   fn variant_to_unscaled_decimal32(
       variant: Variant<'_, '_>,
       precision: u8,
       scale: u8,
   ) -> Result<i32> {
       match variant {
           Variant::Decimal4(d) => rescale_decimal::<Decimal32, Decimal32>(
               d.integer(), VariantDecimal4::MAX_PRECISION, d.scale(), precision, scale),
           Variant::Decimal8(d) => rescale_decimal::<Decimal64, Decimal32>(
               d.integer(), VariantDecimal8::MAX_PRECISION, d.scale(), precision, scale),
           Variant::Decimal16(d) => rescale_decimal::<Decimal128, Decimal32>(
               d.integer(), VariantDecimal16::MAX_PRECISION, d.scale(), precision, scale),
           Variant::Int8(i) => rescale_decimal::<Decimal32, Decimal32>(
               i, VariantDecimal4::MAX_PRECISION, 0, precision, scale),
           Variant::Int16(i) => rescale_decimal::<Decimal32, Decimal32>(
               i, VariantDecimal4::MAX_PRECISION, 0, precision, scale),
           Variant::Int32(i) => rescale_decimal::<Decimal32, Decimal32>(
               i, VariantDecimal4::MAX_PRECISION, 0, precision, scale),
           Variant::Int64(i) => rescale_decimal::<Decimal64, Decimal32>(
               i, VariantDecimal8::MAX_PRECISION, 0, precision, scale),
           _ => return Err(... not exact numeric data ...),
       }
   }
   ```
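   
   Since the proposed helper returns an `Option` while the caller returns a
   `Result`, the wiring needs an explicit conversion step somewhere. The
   sketch below shows one possible shape using toy stand-ins: `ToyVariant`,
   `toy_rescale`, and the `String` error type are all hypothetical and do not
   reflect the real `Variant` or `ArrowError` types.
   
   ```rust
   /// Toy stand-in for the real `Variant` enum (hypothetical).
   enum ToyVariant {
       Int8(i8),
       Int32(i32),
       Decimal4 { integer: i32, scale: u8 },
       Text(String),
   }
   
   /// Hypothetical stand-in for the proposed shared helper.
   fn toy_rescale(
       value: i128,
       input_scale: i8,
       output_precision: u8,
       output_scale: i8,
   ) -> Option<i128> {
       let delta = i32::from(output_scale) - i32::from(input_scale);
       let pow = 10_i128.checked_pow(delta.unsigned_abs())?;
       let rescaled = if delta >= 0 {
           value.checked_mul(pow)?
       } else {
           // Round half away from zero (an assumption of this sketch).
           let (q, r) = (value / pow, value % pow);
           if r >= pow / 2 {
               q + 1
           } else if r <= -(pow / 2) {
               q - 1
           } else {
               q
           }
       };
       let limit = 10_i128.checked_pow(u32::from(output_precision))?;
       (rescaled > -limit && rescaled < limit).then_some(rescaled)
   }
   
   /// Maps each exact-numeric input to (integer, scale), rescales it, and
   /// turns a `None` from the helper into an error.
   fn toy_to_unscaled_decimal32(
       v: &ToyVariant,
       precision: u8,
       scale: i8,
   ) -> Result<i32, String> {
       let (integer, input_scale) = match v {
           ToyVariant::Int8(i) => (i128::from(*i), 0_i8),
           ToyVariant::Int32(i) => (i128::from(*i), 0_i8),
           ToyVariant::Decimal4 { integer, scale } => {
               let s = i8::try_from(*scale)
                   .map_err(|_| "scale out of range".to_string())?;
               (i128::from(*integer), s)
           }
           ToyVariant::Text(_) => return Err("not exact numeric data".to_string()),
       };
       let rescaled = toy_rescale(integer, input_scale, precision, scale)
           .ok_or_else(|| "value does not fit target precision".to_string())?;
       i32::try_from(rescaled).map_err(|_| "value overflows i32".to_string())
   }
   ```
   
   In this toy version, `ToyVariant::Int32(i32::MAX)` with target precision 9
   and scale 0 yields an error, because `i32::MAX` has ten decimal digits.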
   
   Let me know if you see any potential risks or edge cases I might have 
overlooked.


