scovich commented on code in PR #8552:
URL: https://github.com/apache/arrow-rs/pull/8552#discussion_r2415071425


##########
arrow-cast/src/cast/decimal.rs:
##########
@@ -174,55 +307,20 @@ where
     I::Native: DecimalCast + ArrowNativeTypeOp,
     O::Native: DecimalCast + ArrowNativeTypeOp,
 {
-    let error = cast_decimal_to_decimal_error::<I, O>(output_precision, 
output_scale);
-    let delta_scale = input_scale - output_scale;
-    // if the reduction of the input number through scaling (dividing) is 
greater
-    // than a possible precision loss (plus potential increase via rounding)
-    // every input number will fit into the output type
-    // Example: If we are starting with any number of precision 5 [xxxxx],
-    // then and decrease the scale by 3 will have the following effect on the 
representation:
-    // [xxxxx] -> [xx] (+ 1 possibly, due to rounding).
-    // The rounding may add an additional digit, so the cast to be infallible,
-    // the output type needs to have at least 3 digits of precision.
-    // e.g. Decimal(5, 3) 99.999 to Decimal(3, 0) will result in 100:
-    // [99999] -> [99] + 1 = [100], a cast to Decimal(2, 0) would not be 
possible
-    let is_infallible_cast = (input_precision as i8) - delta_scale < 
(output_precision as i8);
-
-    let div = I::Native::from_decimal(10_i128)
-        .unwrap()
-        .pow_checked(delta_scale as u32)?;
-
-    let half = div.div_wrapping(I::Native::from_usize(2).unwrap());
-    let half_neg = half.neg_wrapping();
-
-    let f = |x: I::Native| {
-        // div is >= 10 and so this cannot overflow
-        let d = x.div_wrapping(div);
-        let r = x.mod_wrapping(div);
+    // make sure we don't perform calculations that don't make sense w/o 
validation
+    validate_decimal_precision_and_scale::<O>(output_precision, output_scale)?;

Review Comment:
   Why did this move? It used to be called only for infallible casts; is it now 
called for all casts?



##########
arrow-cast/src/cast/decimal.rs:
##########
@@ -139,6 +145,133 @@ impl DecimalCast for i256 {
     }
 }
 
+/// Build a rescale function from (input_precision, input_scale) to 
(output_precision, output_scale)
+/// returning a closure `Fn(I::Native) -> Option<O::Native>` that performs the 
conversion.
+pub fn rescale_decimal<I, O>(

Review Comment:
   This refactor seems a bit "backward" to me, which probably causes the 
benchmark regressions:
   * Original code was dispatching to two methods 
(`convert_to_smaller_scale_decimal` and 
`convert_to_bigger_or_equal_scale_decimal`) from two locations 
(`cast_decimal_to_decimal` and `cast_decimal_to_decimal_same_type`). This 
avoided some branching in the inner cast loop, because the branch on direction 
of scale change is taken outside the loop.
   * New code pushes everything down into this new `rescale_decimal` method, 
which not only requires the introduction of a new `is_infallible_cast` helper 
method, but also leaves the two `convert_to_xxx_scale_decimal` methods with 
virtually identical bodies. At that point we may as well eliminate those 
helpers entirely and avoid the code bloat... but the helpers probably existed 
for a reason (to hoist at least some branches out of the inner loop). 
   * The new code also allocates errors that get downgraded to empty options, 
where the original code upgraded empty options to errors. Arrow errors allocate 
strings, so that's a meaningful difference.
   
   I wonder if we should instead do:
   * rework `convert_to_smaller_scale_decimal` and 
`convert_to_bigger_or_equal_scale_decimal` 
     * no longer take `array` or `cast_options` as input
     * return `Ok((f, is_infallible_cast))`, which corresponds to the return type 
`Result<(impl Fn(I::Native) -> Option<O::Native>, bool), ArrowError>`
   * define a new generic `apply_decimal_cast` function helper
     * it takes as input `array`, `cast_options` and the `(impl Fn, bool)` pair 
produced by a `convert_to_xxx_scale_decimal` helper
     * it handles the three ways of applying `f` to an array
   * rework `cast_decimal_to_decimal` and `cast_decimal_to_decimal_same_type` 
to call those functions (see below)
   * `rescale_decimal` would be the single-row equivalent of 
`cast_decimal_to_decimal`, returning `Option<O::Native>`
   * The decimal builder's constructor calls 
`validate_decimal_precision_and_scale` and fails on error, so we don't need to 
validate on a per-row basis.
   
   <details>
   <summary>cast_decimal_to_decimal</summary>
   
   ```rust
   let array: PrimitiveArray<O> = if input_scale > output_scale {
       let (f, is_infallible_cast) = convert_to_smaller_scale_decimal(...)?;
       apply_decimal_cast(array, cast_options, f, is_infallible_cast)?
   } else {
       let (f, is_infallible_cast) = convert_to_bigger_or_equal_scale_decimal(...)?;
       apply_decimal_cast(array, cast_options, f, is_infallible_cast)?
   };
   ```
   
   </details>
   
   <details>
   <summary>rescale_decimal</summary>
   
   ```rust
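   // The helpers return `Result`, while this function returns `Option`,
   // so a setup failure is mapped to `None` via `.ok()?`.
   if input_scale > output_scale {
       let (f, _) = convert_to_smaller_scale_decimal(...).ok()?;
       f(integer)
   } else {
       let (f, _) = convert_to_bigger_or_equal_scale_decimal(...).ok()?;
       f(integer)
   }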
   ```
   
   </details>
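
The split proposed above could look roughly like the following. This is a minimal sketch in plain Rust, not arrow-rs code: `Vec<Option<i128>>` stands in for `PrimitiveArray<O>`, `String` stands in for `ArrowError`, and the `CastOptions` struct is a hypothetical stand-in that carries only a `safe` flag. Unlike the PR, the closure here also performs the per-value precision check, an assumption made to keep the example short.

```rust
/// Hypothetical stand-in for the `safe` flag of arrow's `CastOptions`.
struct CastOptions {
    safe: bool,
}

/// Builds the per-value closure for a scale reduction and reports whether
/// every input value is guaranteed to fit the output precision.
fn convert_to_smaller_scale_decimal(
    input_precision: u8,
    input_scale: i8,
    output_precision: u8,
    output_scale: i8,
) -> Result<(impl Fn(i128) -> Option<i128>, bool), String> {
    // Widen to i32 so the subtraction cannot overflow.
    let delta_scale = i32::from(input_scale) - i32::from(output_scale);
    if delta_scale <= 0 {
        return Err(format!("expected a scale reduction, got delta {delta_scale}"));
    }
    let div = 10_i128
        .checked_pow(delta_scale as u32)
        .ok_or_else(|| format!("10^{delta_scale} overflows i128"))?;
    let max = 10_i128
        .checked_pow(u32::from(output_precision))
        .ok_or_else(|| format!("10^{output_precision} overflows i128"))?;
    let half = div / 2;
    // Dividing removes `delta_scale` digits; rounding may add one back.
    let is_infallible_cast =
        i32::from(input_precision) - delta_scale < i32::from(output_precision);

    let f = move |x: i128| {
        let d = x / div;
        let r = x % div;
        // Round half away from zero.
        let adjusted = if x >= 0 && r >= half {
            d + 1
        } else if x < 0 && r <= -half {
            d - 1
        } else {
            d
        };
        // Assumption of this sketch: the value-level precision check lives here.
        (adjusted.abs() < max).then_some(adjusted)
    };
    Ok((f, is_infallible_cast))
}

/// Applies `f` in one of three ways; the branch is chosen once, outside the loop.
fn apply_decimal_cast(
    array: &[Option<i128>],
    cast_options: &CastOptions,
    f: impl Fn(i128) -> Option<i128>,
    is_infallible_cast: bool,
) -> Result<Vec<Option<i128>>, String> {
    if is_infallible_cast {
        // Every value fits, so a failure here would be a logic error.
        Ok(array
            .iter()
            .map(|&v| v.map(|x| f(x).expect("infallible cast")))
            .collect())
    } else if cast_options.safe {
        // Safe mode: values that do not fit become nulls, no error is built.
        Ok(array.iter().map(|&v| v.and_then(|x| f(x))).collect())
    } else {
        // Strict mode: the empty option is upgraded to an error only on failure.
        array
            .iter()
            .map(|&v| match v {
                None => Ok(None),
                Some(x) => f(x)
                    .map(Some)
                    .ok_or_else(|| format!("{x} does not fit the output decimal type")),
            })
            .collect()
    }
}
```

With these definitions, `convert_to_smaller_scale_decimal(5, 3, 3, 0)` reports an infallible cast and maps 99999 (that is, 99.999) to 100. The same input with an output precision of 2 becomes a null in safe mode and an error otherwise.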



##########
arrow-cast/src/cast/decimal.rs:
##########
@@ -139,6 +145,133 @@ impl DecimalCast for i256 {
     }
 }
 
+/// Build a rescale function from (input_precision, input_scale) to 
(output_precision, output_scale)
+/// returning a closure `Fn(I::Native) -> Option<O::Native>` that performs the 
conversion.
+pub fn rescale_decimal<I, O>(
+    input_precision: u8,
+    input_scale: i8,
+    output_precision: u8,
+    output_scale: i8,
+) -> impl Fn(I::Native) -> Result<O::Native, ArrowError>
+where
+    I: DecimalType,
+    O: DecimalType,
+    I::Native: DecimalCast,
+    O::Native: DecimalCast,
+{
+    let delta_scale = output_scale - input_scale;
+
+    // Determine if the cast is infallible based on precision/scale math
+    let is_infallible_cast =
+        is_infallible_decimal_cast(input_precision, input_scale, 
output_precision, output_scale);
+
+    // Build a single mode once and use a thin closure that calls into it
+    enum RescaleMode<I, O> {
+        SameScale,
+        Up { mul: O },
+        Down { div: I, half: I, half_neg: I },
+        Invalid,
+    }
+
+    let mode = if delta_scale == 0 {
+        RescaleMode::SameScale
+    } else if delta_scale > 0 {
+        match O::Native::from_decimal(10_i128).and_then(|t| 
t.pow_checked(delta_scale as u32).ok())
+        {
+            Some(mul) => RescaleMode::Up { mul },
+            None => RescaleMode::Invalid,
+        }
+    } else {
+        match I::Native::from_decimal(10_i128)
+            .and_then(|t| t.pow_checked(delta_scale.unsigned_abs() as 
u32).ok())
+        {
+            Some(div) => {
+                let half = div.div_wrapping(I::Native::from_usize(2).unwrap());
+                let half_neg = half.neg_wrapping();
+                RescaleMode::Down {
+                    div,
+                    half,
+                    half_neg,
+                }
+            }
+            None => RescaleMode::Invalid,
+        }
+    };
+
+    let f = move |x: I::Native| {
+        match &mode {
+            RescaleMode::SameScale => O::Native::from_decimal(x),
+            RescaleMode::Up { mul } => {
+                O::Native::from_decimal(x).and_then(|x| 
x.mul_checked(*mul).ok())
+            }
+            RescaleMode::Down {
+                div,
+                half,
+                half_neg,
+            } => {
+                // div is >= 10 and so this cannot overflow
+                let d = x.div_wrapping(*div);
+                let r = x.mod_wrapping(*div);
+
+                // Round result
+                let adjusted = match x >= I::Native::ZERO {
+                    true if r >= *half => d.add_wrapping(I::Native::ONE),
+                    false if r <= *half_neg => d.sub_wrapping(I::Native::ONE),
+                    _ => d,
+                };
+                O::Native::from_decimal(adjusted)
+            }
+            RescaleMode::Invalid => None,
+        }
+    };
+
+    let error = cast_decimal_to_decimal_error::<I, O>(output_precision, 
output_scale);
+
+    move |x| {
+        if is_infallible_cast {
+            f(x).ok_or_else(|| error(x))
+        } else {
+            f(x).ok_or_else(|| error(x)).and_then(|v| {
+                O::validate_decimal_precision(v, output_precision, 
output_scale).map(|_| v)
+            })
+        }
+    }
+}
+
+/// Returns true if casting from (input_precision, input_scale) to
+/// (output_precision, output_scale) is infallible based on precision/scale 
math.
+fn is_infallible_decimal_cast(
+    input_precision: u8,
+    input_scale: i8,
+    output_precision: u8,
+    output_scale: i8,
+) -> bool {
+    let delta_scale = output_scale - input_scale;
+    let input_precision_i8 = input_precision as i8;
+    let output_precision_i8 = output_precision as i8;
+    if delta_scale >= 0 {
+        // if the gain in precision (digits) is greater than the 
multiplication due to scaling
+        // every number will fit into the output type
+        // Example: If we are starting with any number of precision 5 [xxxxx],
+        // then an increase of scale by 3 will have the following effect on 
the representation:
+        // [xxxxx] -> [xxxxx000], so for the cast to be infallible, the output 
type
+        // needs to provide at least 8 digits precision
+        input_precision_i8 + delta_scale <= output_precision_i8
+    } else {
+        // if the reduction of the input number through scaling (dividing) is 
greater
+        // than a possible precision loss (plus potential increase via 
rounding)
+        // every input number will fit into the output type
+        // Example: If we are starting with any number of precision 5 [xxxxx],
+        // then and decrease the scale by 3 will have the following effect on 
the representation:
+        // [xxxxx] -> [xx] (+ 1 possibly, due to rounding).
+        // The rounding may add an additional digit, so the cast to be 
infallible,

Review Comment:
   ```suggestion
           // The rounding may add an additional digit, so for the cast to be 
infallible,
   ```



##########
arrow-cast/src/cast/decimal.rs:
##########
@@ -139,6 +145,133 @@ impl DecimalCast for i256 {
     }
 }
 
+/// Build a rescale function from (input_precision, input_scale) to 
(output_precision, output_scale)
+/// returning a closure `Fn(I::Native) -> Option<O::Native>` that performs the 
conversion.
+pub fn rescale_decimal<I, O>(
+    input_precision: u8,
+    input_scale: i8,
+    output_precision: u8,
+    output_scale: i8,
+) -> impl Fn(I::Native) -> Result<O::Native, ArrowError>
+where
+    I: DecimalType,
+    O: DecimalType,
+    I::Native: DecimalCast,
+    O::Native: DecimalCast,
+{
+    let delta_scale = output_scale - input_scale;
+
+    // Determine if the cast is infallible based on precision/scale math
+    let is_infallible_cast =
+        is_infallible_decimal_cast(input_precision, input_scale, 
output_precision, output_scale);
+
+    // Build a single mode once and use a thin closure that calls into it
+    enum RescaleMode<I, O> {
+        SameScale,
+        Up { mul: O },
+        Down { div: I, half: I, half_neg: I },
+        Invalid,
+    }
+
+    let mode = if delta_scale == 0 {
+        RescaleMode::SameScale
+    } else if delta_scale > 0 {
+        match O::Native::from_decimal(10_i128).and_then(|t| 
t.pow_checked(delta_scale as u32).ok())
+        {
+            Some(mul) => RescaleMode::Up { mul },
+            None => RescaleMode::Invalid,
+        }
+    } else {
+        match I::Native::from_decimal(10_i128)
+            .and_then(|t| t.pow_checked(delta_scale.unsigned_abs() as 
u32).ok())
+        {
+            Some(div) => {
+                let half = div.div_wrapping(I::Native::from_usize(2).unwrap());
+                let half_neg = half.neg_wrapping();
+                RescaleMode::Down {
+                    div,
+                    half,
+                    half_neg,
+                }
+            }
+            None => RescaleMode::Invalid,
+        }
+    };
+
+    let f = move |x: I::Native| {
+        match &mode {
+            RescaleMode::SameScale => O::Native::from_decimal(x),
+            RescaleMode::Up { mul } => {
+                O::Native::from_decimal(x).and_then(|x| 
x.mul_checked(*mul).ok())
+            }
+            RescaleMode::Down {
+                div,
+                half,
+                half_neg,
+            } => {
+                // div is >= 10 and so this cannot overflow
+                let d = x.div_wrapping(*div);
+                let r = x.mod_wrapping(*div);
+
+                // Round result
+                let adjusted = match x >= I::Native::ZERO {
+                    true if r >= *half => d.add_wrapping(I::Native::ONE),
+                    false if r <= *half_neg => d.sub_wrapping(I::Native::ONE),
+                    _ => d,
+                };
+                O::Native::from_decimal(adjusted)
+            }
+            RescaleMode::Invalid => None,
+        }
+    };
+
+    let error = cast_decimal_to_decimal_error::<I, O>(output_precision, 
output_scale);
+
+    move |x| {
+        if is_infallible_cast {
+            f(x).ok_or_else(|| error(x))
+        } else {
+            f(x).ok_or_else(|| error(x)).and_then(|v| {
+                O::validate_decimal_precision(v, output_precision, 
output_scale).map(|_| v)
+            })
+        }
+    }
+}
+
+/// Returns true if casting from (input_precision, input_scale) to
+/// (output_precision, output_scale) is infallible based on precision/scale 
math.
+fn is_infallible_decimal_cast(
+    input_precision: u8,
+    input_scale: i8,
+    output_precision: u8,
+    output_scale: i8,
+) -> bool {
+    let delta_scale = output_scale - input_scale;
+    let input_precision_i8 = input_precision as i8;
+    let output_precision_i8 = output_precision as i8;

Review Comment:
   I'm not sure the suffix is helpful? The original value is never reused.
   ```suggestion
       let input_precision = input_precision as i8;
       let output_precision = output_precision as i8;
   ```
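
For illustration, here is a self-contained sketch of the helper with the parameters shadowed in place, as suggested. Two parts of it are assumptions rather than the PR's code: it widens to `i16` instead of `i8`, so the additions cannot overflow, and the scale-down condition is derived from the removed `is_infallible_cast` expression in the original code, since that branch is cut off in the hunk above.

```rust
/// Returns true if every value of the input decimal type is guaranteed to
/// fit the output type, based on precision/scale arithmetic alone.
fn is_infallible_decimal_cast(
    input_precision: u8,
    input_scale: i8,
    output_precision: u8,
    output_scale: i8,
) -> bool {
    // Shadow the parameters: the original u8 values are never reused.
    let input_precision = i16::from(input_precision);
    let output_precision = i16::from(output_precision);
    let delta_scale = i16::from(output_scale) - i16::from(input_scale);
    if delta_scale >= 0 {
        // Scaling up appends `delta_scale` zeros: [xxxxx] -> [xxxxx000].
        input_precision + delta_scale <= output_precision
    } else {
        // Scaling down drops digits, but rounding may add one back:
        // [99999] -> [99] + 1 = [100], hence the strict inequality.
        input_precision + delta_scale < output_precision
    }
}
```

For example, Decimal(5, 0) to Decimal(8, 3) is infallible, and so is Decimal(5, 3) to Decimal(3, 0), whereas Decimal(5, 3) to Decimal(2, 0) is not, because 99.999 rounds to 100.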


