alamb commented on code in PR #8689:
URL: https://github.com/apache/arrow-rs/pull/8689#discussion_r2462744111


##########
arrow-cast/src/cast/decimal.rs:
##########
@@ -223,24 +248,136 @@ where
         O::Native::from_decimal(adjusted)
     };
 
-    Ok(if is_infallible_cast {
-        // make sure we don't perform calculations that don't make sense w/o validation
-        validate_decimal_precision_and_scale::<O>(output_precision, output_scale)?;
-        let g = |x: I::Native| f(x).unwrap(); // unwrapping is safe since the result is guaranteed
-        // to fit into the target type
-        array.unary(g)
+    // if the reduction of the input number through scaling (dividing) is greater
+    // than a possible precision loss (plus potential increase via rounding)
+    // every input number will fit into the output type
+    // Example: If we are starting with any number of precision 5 [xxxxx],
+    // then and decrease the scale by 3 will have the following effect on the representation:
+    // [xxxxx] -> [xx] (+ 1 possibly, due to rounding).
+    // The rounding may add a digit, so the cast to be infallible,
+    // the output type needs to have at least 3 digits of precision.
+    // e.g. Decimal(5, 3) 99.999 to Decimal(3, 0) will result in 100:
+    // [99999] -> [99] + 1 = [100], a cast to Decimal(2, 0) would not be possible
+    let is_infallible_cast = (input_precision as i8) - delta_scale < (output_precision as i8);
+    let f_infallible = is_infallible_cast.then_some(move |x| f_fallible(x).unwrap());
+    Some((f_fallible, f_infallible))
+}
+
+/// Apply the rescaler function to the value.
+/// If the rescaler is infallible, use the infallible function.
+/// Otherwise, use the fallible function and validate the precision.
+fn apply_rescaler<I: DecimalType, O: DecimalType>(
+    value: I::Native,
+    output_precision: u8,
+    f: impl Fn(I::Native) -> Option<O::Native>,
+    f_infallible: Option<impl Fn(I::Native) -> O::Native>,
+) -> Option<O::Native>
+where
+    I::Native: DecimalCast,
+    O::Native: DecimalCast,
+{
+    if let Some(f_infallible) = f_infallible {
+        Some(f_infallible(value))
+    } else {
+        f(value).filter(|v| O::is_valid_decimal_precision(*v, output_precision))
+    }
+}
+
+/// Rescales a decimal value from `(input_precision, input_scale)` to
+/// `(output_precision, output_scale)` and returns the converted number when it fits
+/// within the output precision.
+///
+/// The function first validates that the requested precision and scale are supported for
+/// both the source and destination decimal types. It then either upscales (multiplying
+/// by an appropriate power of ten) or downscales (dividing with rounding) the input value.
+/// When the scaling factor exceeds the precision table of the destination type, the value
+/// is treated as an overflow for upscaling, or rounded to zero for downscaling (as any
+/// possible result would be zero at the requested scale).
+///
+/// This mirrors the column-oriented helpers of decimal casting but operates on a single value
+/// (row-level) instead of an entire array.
+///
+/// Returns `None` if the value cannot be represented with the requested precision.
+pub fn rescale_decimal<I: DecimalType, O: DecimalType>(
+    value: I::Native,
+    input_precision: u8,
+    input_scale: i8,
+    output_precision: u8,
+    output_scale: i8,
+) -> Option<O::Native>
+where
+    I::Native: DecimalCast + ArrowNativeTypeOp,
+    O::Native: DecimalCast + ArrowNativeTypeOp,
+{
+    validate_decimal_precision_and_scale::<I>(input_precision, input_scale).ok()?;
+    validate_decimal_precision_and_scale::<O>(output_precision, output_scale).ok()?;
+
+    if input_scale <= output_scale {
+        let (f, f_infallible) =
+            make_upscaler::<I, O>(input_precision, input_scale, output_precision, output_scale)?;
+        apply_rescaler::<I, O>(value, output_precision, f, f_infallible)
+    } else {
+        let Some((f, f_infallible)) =
+            make_downscaler::<I, O>(input_precision, input_scale, output_precision, output_scale)
+        else {
+            // Scale reduction exceeds supported precision; result mathematically rounds to zero
+            return Some(O::Native::ZERO);
+        };
+        apply_rescaler::<I, O>(value, output_precision, f, f_infallible)
+    }
+}
+
+fn cast_decimal_to_decimal_error<I, O>(
+    output_precision: u8,
+    output_scale: i8,
+) -> impl Fn(<I as ArrowPrimitiveType>::Native) -> ArrowError
+where
+    I: DecimalType,
+    O: DecimalType,
+    I::Native: DecimalCast + ArrowNativeTypeOp,
+    O::Native: DecimalCast + ArrowNativeTypeOp,
+{
+    move |x: I::Native| {
+        ArrowError::CastError(format!(
+            "Cannot cast to {}({}, {}). Overflowing on {:?}",
+            O::PREFIX,
+            output_precision,
+            output_scale,
+            x
+        ))
+    }
+}
+
+fn apply_decimal_cast<I: DecimalType, O: DecimalType>(
+    array: &PrimitiveArray<I>,
+    output_precision: u8,
+    output_scale: i8,
+    f_fallible: impl Fn(I::Native) -> Option<O::Native>,
+    f_infallible: Option<impl Fn(I::Native) -> O::Native>,
+    cast_options: &CastOptions,
+) -> Result<PrimitiveArray<O>, ArrowError>
+where
+    I::Native: DecimalCast + ArrowNativeTypeOp,
+    O::Native: DecimalCast + ArrowNativeTypeOp,
+{
+    let array = if let Some(f_infallible) = f_infallible {

Review Comment:
   this is a very nice formulation now



-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

To unsubscribe, e-mail: github-unsubscribe@arrow.apache.org

For queries about this service, please contact Infrastructure at:
users@infra.apache.org

Reply via email to