This is an automated email from the ASF dual-hosted git repository.
alamb pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/arrow-rs.git
The following commit(s) were added to refs/heads/main by this push:
new 89846a8edf [Variant] `rescale_decimal` followup (#8655)
89846a8edf is described below
commit 89846a8edf700e70c14d38829413ce822a8f0ee1
Author: Liam Bao <[email protected]>
AuthorDate: Wed Oct 22 10:45:52 2025 -0400
[Variant] `rescale_decimal` followup (#8655)
# Which issue does this PR close?
- Followup of #8552.
# Rationale for this change
Code cleanup and optimization
# What changes are included in this PR?
Addressed the follow-up review comments on #8552 and refactored/optimized the method
`rescale_decimal`.
# Are these changes tested?
Covered by existing tests
# Are there any user-facing changes?
No
---
arrow-array/src/types.rs | 14 +++--
parquet-variant-compute/src/type_conversion.rs | 85 ++++++++++++--------------
2 files changed, 47 insertions(+), 52 deletions(-)
diff --git a/arrow-array/src/types.rs b/arrow-array/src/types.rs
index fda19242ee..fcd2d6958f 100644
--- a/arrow-array/src/types.rs
+++ b/arrow-array/src/types.rs
@@ -1324,7 +1324,7 @@ pub trait DecimalType:
/// Maximum no of digits after the decimal point (note the scale can be
negative)
const MAX_SCALE: i8;
/// The maximum value for each precision in `0..=MAX_PRECISION`: [0, 9,
99, ...]
- const MAX_FOR_EACH_PRECISION: &[Self::Native];
+ const MAX_FOR_EACH_PRECISION: &'static [Self::Native];
/// fn to create its [`DataType`]
const TYPE_CONSTRUCTOR: fn(u8, i8) -> DataType;
/// Default values for [`DataType`]
@@ -1395,7 +1395,8 @@ impl DecimalType for Decimal32Type {
const BYTE_LENGTH: usize = 4;
const MAX_PRECISION: u8 = DECIMAL32_MAX_PRECISION;
const MAX_SCALE: i8 = DECIMAL32_MAX_SCALE;
- const MAX_FOR_EACH_PRECISION: &[i32] =
&arrow_data::decimal::MAX_DECIMAL32_FOR_EACH_PRECISION;
+ const MAX_FOR_EACH_PRECISION: &'static [i32] =
+ &arrow_data::decimal::MAX_DECIMAL32_FOR_EACH_PRECISION;
const TYPE_CONSTRUCTOR: fn(u8, i8) -> DataType = DataType::Decimal32;
const DEFAULT_TYPE: DataType =
DataType::Decimal32(DECIMAL32_MAX_PRECISION, DECIMAL32_DEFAULT_SCALE);
@@ -1430,7 +1431,8 @@ impl DecimalType for Decimal64Type {
const BYTE_LENGTH: usize = 8;
const MAX_PRECISION: u8 = DECIMAL64_MAX_PRECISION;
const MAX_SCALE: i8 = DECIMAL64_MAX_SCALE;
- const MAX_FOR_EACH_PRECISION: &[i64] =
&arrow_data::decimal::MAX_DECIMAL64_FOR_EACH_PRECISION;
+ const MAX_FOR_EACH_PRECISION: &'static [i64] =
+ &arrow_data::decimal::MAX_DECIMAL64_FOR_EACH_PRECISION;
const TYPE_CONSTRUCTOR: fn(u8, i8) -> DataType = DataType::Decimal64;
const DEFAULT_TYPE: DataType =
DataType::Decimal64(DECIMAL64_MAX_PRECISION, DECIMAL64_DEFAULT_SCALE);
@@ -1465,7 +1467,8 @@ impl DecimalType for Decimal128Type {
const BYTE_LENGTH: usize = 16;
const MAX_PRECISION: u8 = DECIMAL128_MAX_PRECISION;
const MAX_SCALE: i8 = DECIMAL128_MAX_SCALE;
- const MAX_FOR_EACH_PRECISION: &[i128] =
&arrow_data::decimal::MAX_DECIMAL128_FOR_EACH_PRECISION;
+ const MAX_FOR_EACH_PRECISION: &'static [i128] =
+ &arrow_data::decimal::MAX_DECIMAL128_FOR_EACH_PRECISION;
const TYPE_CONSTRUCTOR: fn(u8, i8) -> DataType = DataType::Decimal128;
const DEFAULT_TYPE: DataType =
DataType::Decimal128(DECIMAL128_MAX_PRECISION, DECIMAL_DEFAULT_SCALE);
@@ -1500,7 +1503,8 @@ impl DecimalType for Decimal256Type {
const BYTE_LENGTH: usize = 32;
const MAX_PRECISION: u8 = DECIMAL256_MAX_PRECISION;
const MAX_SCALE: i8 = DECIMAL256_MAX_SCALE;
- const MAX_FOR_EACH_PRECISION: &[i256] =
&arrow_data::decimal::MAX_DECIMAL256_FOR_EACH_PRECISION;
+ const MAX_FOR_EACH_PRECISION: &'static [i256] =
+ &arrow_data::decimal::MAX_DECIMAL256_FOR_EACH_PRECISION;
const TYPE_CONSTRUCTOR: fn(u8, i8) -> DataType = DataType::Decimal256;
const DEFAULT_TYPE: DataType =
DataType::Decimal256(DECIMAL256_MAX_PRECISION, DECIMAL_DEFAULT_SCALE);
diff --git a/parquet-variant-compute/src/type_conversion.rs
b/parquet-variant-compute/src/type_conversion.rs
index 83ffc8f08d..38ca66289b 100644
--- a/parquet-variant-compute/src/type_conversion.rs
+++ b/parquet-variant-compute/src/type_conversion.rs
@@ -189,7 +189,7 @@ where
/// Rescale a decimal from (input_precision, input_scale) to
(output_precision, output_scale)
/// and return the scaled value if it fits the output precision. Similar to
the implementation in
/// decimal.rs in arrow-cast.
-pub(crate) fn rescale_decimal<I, O>(
+pub(crate) fn rescale_decimal<I: DecimalType, O: DecimalType>(
value: I::Native,
input_precision: u8,
input_scale: i8,
@@ -197,32 +197,41 @@ pub(crate) fn rescale_decimal<I, O>(
output_scale: i8,
) -> Option<O::Native>
where
- I: DecimalType,
- O: DecimalType,
I::Native: DecimalCast,
O::Native: DecimalCast,
{
let delta_scale = output_scale - input_scale;
- // Determine if the cast is infallible based on precision/scale math
- let is_infallible_cast =
- is_infallible_decimal_cast(input_precision, input_scale,
output_precision, output_scale);
+ let (scaled, is_infallible_cast) = if delta_scale >= 0 {
+ // O::MAX_FOR_EACH_PRECISION[k] stores 10^k - 1 (e.g., 9, 99, 999,
...).
+ // Adding 1 yields exactly 10^k without computing a power at runtime.
+ // Using the precomputed table avoids pow(10, k) and its
checked/overflow
+ // handling, which is faster and simpler for scaling by 10^delta_scale.
+ let max = O::MAX_FOR_EACH_PRECISION.get(delta_scale as usize)?;
+ let mul = max.add_wrapping(O::Native::ONE);
- let scaled = if delta_scale == 0 {
- O::Native::from_decimal(value)
- } else if delta_scale > 0 {
- let mul = O::Native::from_decimal(10_i128)
- .and_then(|t| t.pow_checked(delta_scale as u32).ok())?;
- O::Native::from_decimal(value).and_then(|x| x.mul_checked(mul).ok())
+ // if the gain in precision (digits) is greater than the
multiplication due to scaling
+ // every number will fit into the output type
+ // Example: If we are starting with any number of precision 5 [xxxxx],
+ // then an increase of scale by 3 will have the following effect on
the representation:
+ // [xxxxx] -> [xxxxx000], so for the cast to be infallible, the output
type
+ // needs to provide at least 8 digits precision
+ let is_infallible_cast = input_precision as i8 + delta_scale <=
output_precision as i8;
+ let value = O::Native::from_decimal(value);
+ let scaled = if is_infallible_cast {
+ Some(value.unwrap().mul_wrapping(mul))
+ } else {
+ value.and_then(|x| x.mul_checked(mul).ok())
+ };
+ (scaled, is_infallible_cast)
} else {
- // delta_scale is guaranteed to be > 0, but may also be larger than
I::MAX_PRECISION. If so, the
- // scale change divides out more digits than the input has precision
and the result of the cast
- // is always zero. For example, if we try to apply delta_scale=10 a
decimal32 value, the largest
- // possible result is 999999999/10000000000 = 0.0999999999, which
rounds to zero. Smaller values
- // (e.g. 1/10000000000) or larger delta_scale (e.g.
999999999/10000000000000) produce even
- // smaller results, which also round to zero. In that case, just
return an array of zeros.
- let delta_scale = delta_scale.unsigned_abs() as usize;
- let Some(max) = I::MAX_FOR_EACH_PRECISION.get(delta_scale) else {
+ // the abs of delta_scale is guaranteed to be > 0, but may also be
larger than I::MAX_PRECISION.
+ // If so, the scale change divides out more digits than the input has
precision and the result
+ // of the cast is always zero. For example, if we try to apply
delta_scale=10 a decimal32 value,
+ // the largest possible result is 999999999/10000000000 =
0.0999999999, which rounds to zero.
+ // Smaller values (e.g. 1/10000000000) or larger delta_scale (e.g.
999999999/10000000000000)
+ // produce even smaller results, which also round to zero. In that
case, just return zero.
+ let Some(max) =
I::MAX_FOR_EACH_PRECISION.get(delta_scale.unsigned_abs() as usize) else {
return Some(O::Native::ZERO);
};
let div = max.add_wrapping(I::Native::ONE);
@@ -239,43 +248,25 @@ where
false if r <= half_neg => d.sub_wrapping(I::Native::ONE),
_ => d,
};
- O::Native::from_decimal(adjusted)
- };
- scaled.filter(|v| is_infallible_cast || O::is_valid_decimal_precision(*v,
output_precision))
-}
-
-/// Returns true if casting from (input_precision, input_scale) to
-/// (output_precision, output_scale) is infallible based on precision/scale
math.
-fn is_infallible_decimal_cast(
- input_precision: u8,
- input_scale: i8,
- output_precision: u8,
- output_scale: i8,
-) -> bool {
- let delta_scale = output_scale - input_scale;
- let input_precision = input_precision as i8;
- let output_precision = output_precision as i8;
- if delta_scale >= 0 {
- // if the gain in precision (digits) is greater than the
multiplication due to scaling
- // every number will fit into the output type
- // Example: If we are starting with any number of precision 5 [xxxxx],
- // then an increase of scale by 3 will have the following effect on
the representation:
- // [xxxxx] -> [xxxxx000], so for the cast to be infallible, the output
type
- // needs to provide at least 8 digits precision
- input_precision + delta_scale <= output_precision
- } else {
// if the reduction of the input number through scaling (dividing) is
greater
// than a possible precision loss (plus potential increase via
rounding)
// every input number will fit into the output type
// Example: If we are starting with any number of precision 5 [xxxxx],
// then and decrease the scale by 3 will have the following effect on
the representation:
// [xxxxx] -> [xx] (+ 1 possibly, due to rounding).
- // The rounding may add an additional digit, so for the cast to be
infallible,
+ // The rounding may add a digit, so for the cast to be infallible,
// the output type needs to have at least 3 digits of precision.
// e.g. Decimal(5, 3) 99.999 to Decimal(3, 0) will result in 100:
// [99999] -> [99] + 1 = [100], a cast to Decimal(2, 0) would not be
possible
- input_precision + delta_scale < output_precision
+ let is_infallible_cast = input_precision as i8 + delta_scale <
output_precision as i8;
+ (O::Native::from_decimal(adjusted), is_infallible_cast)
+ };
+
+ if is_infallible_cast {
+ scaled
+ } else {
+ scaled.filter(|v| O::is_valid_decimal_precision(*v, output_precision))
}
}