This is an automated email from the ASF dual-hosted git repository.

agrove pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/arrow-rs.git


The following commit(s) were added to refs/heads/master by this push:
     new c90713b3b perf: Faster decimal precision overflow checks (#6419)
c90713b3b is described below

commit c90713b3b0947644662b563bde5df70e3995c12e
Author: Andy Grove <[email protected]>
AuthorDate: Sat Sep 21 08:20:42 2024 -0600

    perf: Faster decimal precision overflow checks (#6419)
    
    * add benchmark
    
    * add optimization
    
    * fix
    
    * fix
    
    * cargo fmt
    
    * clippy
    
    * Update arrow-data/src/decimal.rs
    
    Co-authored-by: Liang-Chi Hsieh <[email protected]>
    
    * optimize to avoid allocating an idx variable
    
    * revert change to public api
    
    * fix error in rustdoc
    
    ---------
    
    Co-authored-by: Liang-Chi Hsieh <[email protected]>
---
 arrow-array/Cargo.toml                   |   4 +
 arrow-array/benches/decimal_overflow.rs  |  53 ++++++++++
 arrow-array/src/array/primitive_array.rs |   4 +-
 arrow-array/src/types.rs                 |  16 ++-
 arrow-cast/src/cast/decimal.rs           |  10 +-
 arrow-cast/src/cast/mod.rs               |  14 +--
 arrow-data/src/decimal.rs                | 165 ++++++++++++++++++++++++++-----
 7 files changed, 224 insertions(+), 42 deletions(-)

diff --git a/arrow-array/Cargo.toml b/arrow-array/Cargo.toml
index 57b86c192..d993d36b8 100644
--- a/arrow-array/Cargo.toml
+++ b/arrow-array/Cargo.toml
@@ -71,3 +71,7 @@ harness = false
 [[bench]]
 name = "fixed_size_list_array"
 harness = false
+
+[[bench]]
+name = "decimal_overflow"
+harness = false
diff --git a/arrow-array/benches/decimal_overflow.rs b/arrow-array/benches/decimal_overflow.rs
new file mode 100644
index 000000000..8f22b4b47
--- /dev/null
+++ b/arrow-array/benches/decimal_overflow.rs
@@ -0,0 +1,53 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements.  See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership.  The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License.  You may obtain a copy of the License at
+//
+//   http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied.  See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+use arrow_array::builder::{Decimal128Builder, Decimal256Builder};
+use arrow_buffer::i256;
+use criterion::*;
+
+fn criterion_benchmark(c: &mut Criterion) {
+    let len = 8192;
+    let mut builder_128 = Decimal128Builder::with_capacity(len);
+    let mut builder_256 = Decimal256Builder::with_capacity(len);
+    for i in 0..len {
+        if i % 10 == 0 {
+            builder_128.append_value(i128::MAX);
+            builder_256.append_value(i256::from_i128(i128::MAX));
+        } else {
+            builder_128.append_value(i as i128);
+            builder_256.append_value(i256::from_i128(i as i128));
+        }
+    }
+    let array_128 = builder_128.finish();
+    let array_256 = builder_256.finish();
+
+    c.bench_function("validate_decimal_precision_128", |b| {
+        b.iter(|| black_box(array_128.validate_decimal_precision(8)));
+    });
+    c.bench_function("null_if_overflow_precision_128", |b| {
+        b.iter(|| black_box(array_128.null_if_overflow_precision(8)));
+    });
+    c.bench_function("validate_decimal_precision_256", |b| {
+        b.iter(|| black_box(array_256.validate_decimal_precision(8)));
+    });
+    c.bench_function("null_if_overflow_precision_256", |b| {
+        b.iter(|| black_box(array_256.null_if_overflow_precision(8)));
+    });
+}
+
+criterion_group!(benches, criterion_benchmark);
+criterion_main!(benches);
diff --git a/arrow-array/src/array/primitive_array.rs b/arrow-array/src/array/primitive_array.rs
index 521ef088e..567fa00e7 100644
--- a/arrow-array/src/array/primitive_array.rs
+++ b/arrow-array/src/array/primitive_array.rs
@@ -1570,9 +1570,7 @@ impl<T: DecimalType + ArrowPrimitiveType> PrimitiveArray<T> {
     /// Validates the Decimal Array, if the value of slot is overflow for the specified precision, and
     /// will be casted to Null
     pub fn null_if_overflow_precision(&self, precision: u8) -> Self {
-        self.unary_opt::<_, T>(|v| {
-            (T::validate_decimal_precision(v, precision).is_ok()).then_some(v)
-        })
+        self.unary_opt::<_, T>(|v| T::is_valid_decimal_precision(v, precision).then_some(v))
     }
 
     /// Returns [`Self::value`] formatted as a string
diff --git a/arrow-array/src/types.rs b/arrow-array/src/types.rs
index b39c9c403..92262fc04 100644
--- a/arrow-array/src/types.rs
+++ b/arrow-array/src/types.rs
@@ -24,7 +24,10 @@ use crate::temporal_conversions::as_datetime_with_timezone;
 use crate::timezone::Tz;
 use crate::{ArrowNativeTypeOp, OffsetSizeTrait};
 use arrow_buffer::{i256, Buffer, OffsetBuffer};
-use arrow_data::decimal::{validate_decimal256_precision, validate_decimal_precision};
+use arrow_data::decimal::{
+    is_validate_decimal256_precision, is_validate_decimal_precision, validate_decimal256_precision,
+    validate_decimal_precision,
+};
 use arrow_data::{validate_binary_view, validate_string_view};
 use arrow_schema::{
     ArrowError, DataType, IntervalUnit, TimeUnit, DECIMAL128_MAX_PRECISION, DECIMAL128_MAX_SCALE,
@@ -1194,6 +1197,9 @@ pub trait DecimalType:
 
     /// Validates that `value` contains no more than `precision` decimal digits
     fn validate_decimal_precision(value: Self::Native, precision: u8) -> Result<(), ArrowError>;
+
+    /// Determines whether `value` contains no more than `precision` decimal digits
+    fn is_valid_decimal_precision(value: Self::Native, precision: u8) -> bool;
 }
 
 /// Validate that `precision` and `scale` are valid for `T`
@@ -1256,6 +1262,10 @@ impl DecimalType for Decimal128Type {
     fn validate_decimal_precision(num: i128, precision: u8) -> Result<(), ArrowError> {
         validate_decimal_precision(num, precision)
     }
+
+    fn is_valid_decimal_precision(value: Self::Native, precision: u8) -> bool {
+        is_validate_decimal_precision(value, precision)
+    }
 }
 
 impl ArrowPrimitiveType for Decimal128Type {
@@ -1286,6 +1296,10 @@ impl DecimalType for Decimal256Type {
     fn validate_decimal_precision(num: i256, precision: u8) -> Result<(), ArrowError> {
         validate_decimal256_precision(num, precision)
     }
+
+    fn is_valid_decimal_precision(value: Self::Native, precision: u8) -> bool {
+        is_validate_decimal256_precision(value, precision)
+    }
 }
 
 impl ArrowPrimitiveType for Decimal256Type {
diff --git a/arrow-cast/src/cast/decimal.rs b/arrow-cast/src/cast/decimal.rs
index 600f868a3..637cbc417 100644
--- a/arrow-cast/src/cast/decimal.rs
+++ b/arrow-cast/src/cast/decimal.rs
@@ -336,11 +336,7 @@ where
     if cast_options.safe {
         let iter = from.iter().map(|v| {
         v.and_then(|v| parse_string_to_decimal_native::<T>(v, scale as usize).ok())
-                .and_then(|v| {
-                    T::validate_decimal_precision(v, precision)
-                        .is_ok()
-                        .then_some(v)
-                })
+                .and_then(|v| T::is_valid_decimal_precision(v, precision).then_some(v))
         });
         // Benefit:
         //     20% performance improvement
@@ -430,7 +426,7 @@ where
                 (mul * v.as_())
                     .round()
                     .to_i128()
-                    .filter(|v| Decimal128Type::validate_decimal_precision(*v, precision).is_ok())
+                    .filter(|v| Decimal128Type::is_valid_decimal_precision(*v, precision))
             })
             .with_precision_and_scale(precision, scale)
             .map(|a| Arc::new(a) as ArrayRef)
@@ -473,7 +469,7 @@ where
         array
             .unary_opt::<_, Decimal256Type>(|v| {
                 i256::from_f64((v.as_() * mul).round())
-                    .filter(|v| Decimal256Type::validate_decimal_precision(*v, precision).is_ok())
+                    .filter(|v| Decimal256Type::is_valid_decimal_precision(*v, precision))
             })
             .with_precision_and_scale(precision, scale)
             .map(|a| Arc::new(a) as ArrayRef)
diff --git a/arrow-cast/src/cast/mod.rs b/arrow-cast/src/cast/mod.rs
index e80d497c8..25ef243e1 100644
--- a/arrow-cast/src/cast/mod.rs
+++ b/arrow-cast/src/cast/mod.rs
@@ -327,9 +327,10 @@ where
     let array = if scale < 0 {
         match cast_options.safe {
             true => array.unary_opt::<_, D>(|v| {
-                v.as_().div_checked(scale_factor).ok().and_then(|v| {
-                    (D::validate_decimal_precision(v, precision).is_ok()).then_some(v)
-                })
+                v.as_()
+                    .div_checked(scale_factor)
+                    .ok()
+                    .and_then(|v| (D::is_valid_decimal_precision(v, precision)).then_some(v))
             }),
             false => array.try_unary::<_, D, _>(|v| {
                 v.as_()
@@ -340,9 +341,10 @@ where
     } else {
         match cast_options.safe {
             true => array.unary_opt::<_, D>(|v| {
-                v.as_().mul_checked(scale_factor).ok().and_then(|v| {
-                    (D::validate_decimal_precision(v, precision).is_ok()).then_some(v)
-                })
+                v.as_()
+                    .mul_checked(scale_factor)
+                    .ok()
+                    .and_then(|v| (D::is_valid_decimal_precision(v, precision)).then_some(v))
             }),
             false => array.try_unary::<_, D, _>(|v| {
                 v.as_()
diff --git a/arrow-data/src/decimal.rs b/arrow-data/src/decimal.rs
index 74279bfb9..d9028591a 100644
--- a/arrow-data/src/decimal.rs
+++ b/arrow-data/src/decimal.rs
@@ -23,10 +23,13 @@ pub use arrow_schema::{
     DECIMAL_DEFAULT_SCALE,
 };
 
-// MAX decimal256 value of little-endian format for each precision.
-// Each element is the max value of signed 256-bit integer for the specified precision which
-// is encoded to the 32-byte width format of little-endian.
-pub(crate) const MAX_DECIMAL_BYTES_FOR_LARGER_EACH_PRECISION: [i256; 76] = [
+/// MAX decimal256 value of little-endian format for each precision.
+/// Each element is the max value of signed 256-bit integer for the specified precision which
+/// is encoded to the 32-byte width format of little-endian.
+/// The first element is unused and is inserted so that we can look up using
+/// precision as the index without the need to subtract 1 first.
+pub(crate) const MAX_DECIMAL_BYTES_FOR_LARGER_EACH_PRECISION: [i256; 77] = [
+    i256::from_i128(0_i128), // unused first element
     i256::from_le_bytes([
         9, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 
0, 0, 0, 0, 0, 0, 0,
         0, 0,
@@ -333,10 +336,13 @@ pub(crate) const MAX_DECIMAL_BYTES_FOR_LARGER_EACH_PRECISION: [i256; 76] = [
     ]),
 ];
 
-// MIN decimal256 value of little-endian format for each precision.
-// Each element is the min value of signed 256-bit integer for the specified precision which
-// is encoded to the 76-byte width format of little-endian.
-pub(crate) const MIN_DECIMAL_BYTES_FOR_LARGER_EACH_PRECISION: [i256; 76] = [
+/// MIN decimal256 value of little-endian format for each precision.
+/// Each element is the min value of signed 256-bit integer for the specified precision which
+/// is encoded to the 76-byte width format of little-endian.
+/// The first element is unused and is inserted so that we can look up using
+/// precision as the index without the need to subtract 1 first.
+pub(crate) const MIN_DECIMAL_BYTES_FOR_LARGER_EACH_PRECISION: [i256; 77] = [
+    i256::from_i128(0_i128), // unused first element
     i256::from_le_bytes([
         247, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255,
         255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255,
@@ -643,8 +649,9 @@ pub(crate) const MIN_DECIMAL_BYTES_FOR_LARGER_EACH_PRECISION: [i256; 76] = [
     ]),
 ];
 
-/// `MAX_DECIMAL_FOR_EACH_PRECISION[p]` holds the maximum `i128` value that can
+/// `MAX_DECIMAL_FOR_EACH_PRECISION[p-1]` holds the maximum `i128` value that can
 /// be stored in [arrow_schema::DataType::Decimal128] value of precision `p`
+#[allow(dead_code)] // no longer used but is part of our public API
 pub const MAX_DECIMAL_FOR_EACH_PRECISION: [i128; 38] = [
     9,
     99,
@@ -686,8 +693,9 @@ pub const MAX_DECIMAL_FOR_EACH_PRECISION: [i128; 38] = [
     99999999999999999999999999999999999999,
 ];
 
-/// `MIN_DECIMAL_FOR_EACH_PRECISION[p]` holds the minimum `i128` value that can
+/// `MIN_DECIMAL_FOR_EACH_PRECISION[p-1]` holds the minimum `i128` value that can
 /// be stored in a [arrow_schema::DataType::Decimal128] value of precision `p`
+#[allow(dead_code)] // no longer used but is part of our public API
 pub const MIN_DECIMAL_FOR_EACH_PRECISION: [i128; 38] = [
     -9,
     -99,
@@ -729,6 +737,98 @@ pub const MIN_DECIMAL_FOR_EACH_PRECISION: [i128; 38] = [
     -99999999999999999999999999999999999999,
 ];
 
+/// `MAX_DECIMAL_FOR_EACH_PRECISION_ONE_BASED[p]` holds the maximum `i128` value that can
+/// be stored in [arrow_schema::DataType::Decimal128] value of precision `p`.
+/// The first element is unused and is inserted so that we can look up using
+/// precision as the index without the need to subtract 1 first.
+pub(crate) const MAX_DECIMAL_FOR_EACH_PRECISION_ONE_BASED: [i128; 39] = [
+    0, // unused first element
+    9,
+    99,
+    999,
+    9999,
+    99999,
+    999999,
+    9999999,
+    99999999,
+    999999999,
+    9999999999,
+    99999999999,
+    999999999999,
+    9999999999999,
+    99999999999999,
+    999999999999999,
+    9999999999999999,
+    99999999999999999,
+    999999999999999999,
+    9999999999999999999,
+    99999999999999999999,
+    999999999999999999999,
+    9999999999999999999999,
+    99999999999999999999999,
+    999999999999999999999999,
+    9999999999999999999999999,
+    99999999999999999999999999,
+    999999999999999999999999999,
+    9999999999999999999999999999,
+    99999999999999999999999999999,
+    999999999999999999999999999999,
+    9999999999999999999999999999999,
+    99999999999999999999999999999999,
+    999999999999999999999999999999999,
+    9999999999999999999999999999999999,
+    99999999999999999999999999999999999,
+    999999999999999999999999999999999999,
+    9999999999999999999999999999999999999,
+    99999999999999999999999999999999999999,
+];
+
+/// `MIN_DECIMAL_FOR_EACH_PRECISION[p]` holds the minimum `i128` value that can
+/// be stored in a [arrow_schema::DataType::Decimal128] value of precision `p`.
+/// The first element is unused and is inserted so that we can look up using
+/// precision as the index without the need to subtract 1 first.
+pub(crate) const MIN_DECIMAL_FOR_EACH_PRECISION_ONE_BASED: [i128; 39] = [
+    0, // unused first element
+    -9,
+    -99,
+    -999,
+    -9999,
+    -99999,
+    -999999,
+    -9999999,
+    -99999999,
+    -999999999,
+    -9999999999,
+    -99999999999,
+    -999999999999,
+    -9999999999999,
+    -99999999999999,
+    -999999999999999,
+    -9999999999999999,
+    -99999999999999999,
+    -999999999999999999,
+    -9999999999999999999,
+    -99999999999999999999,
+    -999999999999999999999,
+    -9999999999999999999999,
+    -99999999999999999999999,
+    -999999999999999999999999,
+    -9999999999999999999999999,
+    -99999999999999999999999999,
+    -999999999999999999999999999,
+    -9999999999999999999999999999,
+    -99999999999999999999999999999,
+    -999999999999999999999999999999,
+    -9999999999999999999999999999999,
+    -99999999999999999999999999999999,
+    -999999999999999999999999999999999,
+    -9999999999999999999999999999999999,
+    -99999999999999999999999999999999999,
+    -999999999999999999999999999999999999,
+    -9999999999999999999999999999999999999,
+    -99999999999999999999999999999999999999,
+];
+
 /// Validates that the specified `i128` value can be properly
 /// interpreted as a Decimal number with precision `precision`
 #[inline]
@@ -738,23 +838,30 @@ pub fn validate_decimal_precision(value: i128, precision: u8) -> Result<(), Arro
             "Max precision of a Decimal128 is {DECIMAL128_MAX_PRECISION}, but got {precision}",
         )));
     }
-
-    let max = MAX_DECIMAL_FOR_EACH_PRECISION[usize::from(precision) - 1];
-    let min = MIN_DECIMAL_FOR_EACH_PRECISION[usize::from(precision) - 1];
-
-    if value > max {
+    if value > MAX_DECIMAL_FOR_EACH_PRECISION_ONE_BASED[precision as usize] {
         Err(ArrowError::InvalidArgumentError(format!(
-            "{value} is too large to store in a Decimal128 of precision {precision}. Max is {max}"
+            "{value} is too large to store in a Decimal128 of precision {precision}. Max is {}",
+            MAX_DECIMAL_FOR_EACH_PRECISION_ONE_BASED[precision as usize]
         )))
-    } else if value < min {
+    } else if value < MIN_DECIMAL_FOR_EACH_PRECISION_ONE_BASED[precision as usize] {
         Err(ArrowError::InvalidArgumentError(format!(
-            "{value} is too small to store in a Decimal128 of precision {precision}. Min is {min}"
+            "{value} is too small to store in a Decimal128 of precision {precision}. Min is {}",
+            MIN_DECIMAL_FOR_EACH_PRECISION_ONE_BASED[precision as usize]
         )))
     } else {
         Ok(())
     }
 }
 
+/// Determines whether the specified `i128` value can be properly
+/// interpreted as a Decimal number with precision `precision`
+#[inline]
+pub fn is_validate_decimal_precision(value: i128, precision: u8) -> bool {
+    precision <= DECIMAL128_MAX_PRECISION
+        && value >= MIN_DECIMAL_FOR_EACH_PRECISION_ONE_BASED[precision as usize]
+        && value <= MAX_DECIMAL_FOR_EACH_PRECISION_ONE_BASED[precision as usize]
+}
+
 /// Validates that the specified `i256` of value can be properly
 /// interpreted as a Decimal256 number with precision `precision`
 #[inline]
@@ -764,18 +871,26 @@ pub fn validate_decimal256_precision(value: i256, precision: u8) -> Result<(), A
             "Max precision of a Decimal256 is {DECIMAL256_MAX_PRECISION}, but got {precision}",
         )));
     }
-    let max = MAX_DECIMAL_BYTES_FOR_LARGER_EACH_PRECISION[usize::from(precision) - 1];
-    let min = MIN_DECIMAL_BYTES_FOR_LARGER_EACH_PRECISION[usize::from(precision) - 1];
-
-    if value > max {
+    if value > MAX_DECIMAL_BYTES_FOR_LARGER_EACH_PRECISION[precision as usize] {
         Err(ArrowError::InvalidArgumentError(format!(
-            "{value:?} is too large to store in a Decimal256 of precision {precision}. Max is {max:?}"
+            "{value:?} is too large to store in a Decimal256 of precision {precision}. Max is {:?}",
+            MAX_DECIMAL_BYTES_FOR_LARGER_EACH_PRECISION[precision as usize]
         )))
-    } else if value < min {
+    } else if value < MIN_DECIMAL_BYTES_FOR_LARGER_EACH_PRECISION[precision as usize] {
         Err(ArrowError::InvalidArgumentError(format!(
-            "{value:?} is too small to store in a Decimal256 of precision {precision}. Min is {min:?}"
+            "{value:?} is too small to store in a Decimal256 of precision {precision}. Min is {:?}",
+            MIN_DECIMAL_BYTES_FOR_LARGER_EACH_PRECISION[precision as usize]
         )))
     } else {
         Ok(())
     }
 }
+
+/// Determines whether the specified `i256` value can be properly
+/// interpreted as a Decimal256 number with precision `precision`
+#[inline]
+pub fn is_validate_decimal256_precision(value: i256, precision: u8) -> bool {
+    precision <= DECIMAL256_MAX_PRECISION
+        && value >= MIN_DECIMAL_BYTES_FOR_LARGER_EACH_PRECISION[precision as usize]
+        && value <= MAX_DECIMAL_BYTES_FOR_LARGER_EACH_PRECISION[precision as usize]
+}

Reply via email to