This is an automated email from the ASF dual-hosted git repository.

ytyou pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/datafusion.git


The following commit(s) were added to refs/heads/main by this push:
     new 4b9a468cc1 feat: Add 
`ScalarValue::{new_one,new_zero,new_ten,distance}` support for `Decimal128` and 
`Decimal256` (#16831)
4b9a468cc1 is described below

commit 4b9a468cc1949062cf3cd8685ba8ced377fd212e
Author: theirix <thei...@gmail.com>
AuthorDate: Sun Jul 27 05:09:38 2025 +0100

    feat: Add `ScalarValue::{new_one,new_zero,new_ten,distance}` support for 
`Decimal128` and `Decimal256` (#16831)
    
    * Add missing ScalarValue impls for large decimals
    
    Add methods distance, new_zero, new_one, new_ten for Decimal128,
    Decimal256
    
    * Support expr simplification for Decimal256
    
    * Replace lookup table with i128::pow
    
    * Support different scales for Decimal constructors
    
    - Allow constructing one and ten with different scales
    - Add tests for new_one, new_ten
    - Add test for distance
    
    * Revert "Replace lookup table with i128::pow"
    
    This reverts commit ba23e8c40c97088a405a36b8f1e1c84146178b73.
    
    * Use Arrow error directly
---
 datafusion/common/src/scalar/mod.rs                | 301 ++++++++++++++++++++-
 .../optimizer/src/simplify_expressions/utils.rs    |  88 ++++++
 2 files changed, 381 insertions(+), 8 deletions(-)

diff --git a/datafusion/common/src/scalar/mod.rs 
b/datafusion/common/src/scalar/mod.rs
index 62ae19fd5c..1ced4ab825 100644
--- a/datafusion/common/src/scalar/mod.rs
+++ b/datafusion/common/src/scalar/mod.rs
@@ -74,12 +74,13 @@ use arrow::compute::kernels::numeric::{
     add, add_wrapping, div, mul, mul_wrapping, rem, sub, sub_wrapping,
 };
 use arrow::datatypes::{
-    i256, ArrowDictionaryKeyType, ArrowNativeType, ArrowTimestampType, 
DataType,
-    Date32Type, Field, Float32Type, Int16Type, Int32Type, Int64Type, Int8Type,
-    IntervalDayTime, IntervalDayTimeType, IntervalMonthDayNano, 
IntervalMonthDayNanoType,
-    IntervalUnit, IntervalYearMonthType, TimeUnit, TimestampMicrosecondType,
-    TimestampMillisecondType, TimestampNanosecondType, TimestampSecondType, 
UInt16Type,
-    UInt32Type, UInt64Type, UInt8Type, UnionFields, UnionMode, 
DECIMAL128_MAX_PRECISION,
+    i256, validate_decimal_precision_and_scale, ArrowDictionaryKeyType, 
ArrowNativeType,
+    ArrowTimestampType, DataType, Date32Type, Decimal128Type, Decimal256Type, 
Field,
+    Float32Type, Int16Type, Int32Type, Int64Type, Int8Type, IntervalDayTime,
+    IntervalDayTimeType, IntervalMonthDayNano, IntervalMonthDayNanoType, 
IntervalUnit,
+    IntervalYearMonthType, TimeUnit, TimestampMicrosecondType, 
TimestampMillisecondType,
+    TimestampNanosecondType, TimestampSecondType, UInt16Type, UInt32Type, 
UInt64Type,
+    UInt8Type, UnionFields, UnionMode, DECIMAL128_MAX_PRECISION,
 };
 use arrow::util::display::{array_value_to_string, ArrayFormatter, 
FormatOptions};
 use cache::{get_or_create_cached_key_array, get_or_create_cached_null_array};
@@ -1516,6 +1517,34 @@ impl ScalarValue {
             DataType::Float16 => 
ScalarValue::Float16(Some(f16::from_f32(1.0))),
             DataType::Float32 => ScalarValue::Float32(Some(1.0)),
             DataType::Float64 => ScalarValue::Float64(Some(1.0)),
+            DataType::Decimal128(precision, scale) => {
+                validate_decimal_precision_and_scale::<Decimal128Type>(
+                    *precision, *scale,
+                )?;
+                if *scale < 0 {
+                    return _internal_err!("Negative scale is not supported");
+                }
+                match i128::from(10).checked_pow(*scale as u32) {
+                    Some(value) => {
+                        ScalarValue::Decimal128(Some(value), *precision, 
*scale)
+                    }
+                    None => return _internal_err!("Unsupported scale {scale}"),
+                }
+            }
+            DataType::Decimal256(precision, scale) => {
+                validate_decimal_precision_and_scale::<Decimal256Type>(
+                    *precision, *scale,
+                )?;
+                if *scale < 0 {
+                    return _internal_err!("Negative scale is not supported");
+                }
+                match i256::from(10).checked_pow(*scale as u32) {
+                    Some(value) => {
+                        ScalarValue::Decimal256(Some(value), *precision, 
*scale)
+                    }
+                    None => return _internal_err!("Unsupported scale {scale}"),
+                }
+            }
             _ => {
                 return _not_impl_err!(
                     "Can't create an one scalar from data_type 
\"{datatype:?}\""
@@ -1534,6 +1563,34 @@ impl ScalarValue {
             DataType::Float16 => 
ScalarValue::Float16(Some(f16::from_f32(-1.0))),
             DataType::Float32 => ScalarValue::Float32(Some(-1.0)),
             DataType::Float64 => ScalarValue::Float64(Some(-1.0)),
+            DataType::Decimal128(precision, scale) => {
+                validate_decimal_precision_and_scale::<Decimal128Type>(
+                    *precision, *scale,
+                )?;
+                if *scale < 0 {
+                    return _internal_err!("Negative scale is not supported");
+                }
+                match i128::from(10).checked_pow(*scale as u32) {
+                    Some(value) => {
+                        ScalarValue::Decimal128(Some(-value), *precision, 
*scale)
+                    }
+                    None => return _internal_err!("Unsupported scale {scale}"),
+                }
+            }
+            DataType::Decimal256(precision, scale) => {
+                validate_decimal_precision_and_scale::<Decimal256Type>(
+                    *precision, *scale,
+                )?;
+                if *scale < 0 {
+                    return _internal_err!("Negative scale is not supported");
+                }
+                match i256::from(10).checked_pow(*scale as u32) {
+                    Some(value) => {
+                        ScalarValue::Decimal256(Some(-value), *precision, 
*scale)
+                    }
+                    None => return _internal_err!("Unsupported scale {scale}"),
+                }
+            }
             _ => {
                 return _not_impl_err!(
                     "Can't create a negative one scalar from data_type 
\"{datatype:?}\""
@@ -1555,6 +1612,38 @@ impl ScalarValue {
             DataType::Float16 => 
ScalarValue::Float16(Some(f16::from_f32(10.0))),
             DataType::Float32 => ScalarValue::Float32(Some(10.0)),
             DataType::Float64 => ScalarValue::Float64(Some(10.0)),
+            DataType::Decimal128(precision, scale) => {
+                if let Err(err) = 
validate_decimal_precision_and_scale::<Decimal128Type>(
+                    *precision, *scale,
+                ) {
+                    return _internal_err!("Invalid precision and scale {err}");
+                }
+                if *scale <= 0 {
+                    return _internal_err!("Negative scale is not supported");
+                }
+                match i128::from(10).checked_pow((*scale + 1) as u32) {
+                    Some(value) => {
+                        ScalarValue::Decimal128(Some(value), *precision, 
*scale)
+                    }
+                    None => return _internal_err!("Unsupported scale {scale}"),
+                }
+            }
+            DataType::Decimal256(precision, scale) => {
+                if let Err(err) = 
validate_decimal_precision_and_scale::<Decimal256Type>(
+                    *precision, *scale,
+                ) {
+                    return _internal_err!("Invalid precision and scale {err}");
+                }
+                if *scale <= 0 {
+                    return _internal_err!("Negative scale is not supported");
+                }
+                match i256::from(10).checked_pow((*scale + 1) as u32) {
+                    Some(value) => {
+                        ScalarValue::Decimal256(Some(value), *precision, 
*scale)
+                    }
+                    None => return _internal_err!("Unsupported scale {scale}"),
+                }
+            }
             _ => {
                 return _not_impl_err!(
                     "Can't create a ten scalar from data_type \"{datatype:?}\""
@@ -1924,6 +2013,26 @@ impl ScalarValue {
             (Self::Float64(Some(l)), Self::Float64(Some(r))) => {
                 Some((l - r).abs().round() as _)
             }
+            (
+                Self::Decimal128(Some(l), lprecision, lscale),
+                Self::Decimal128(Some(r), rprecision, rscale),
+            ) => {
+                if lprecision == rprecision && lscale == rscale {
+                    l.checked_sub(*r)?.checked_abs()?.to_usize()
+                } else {
+                    None
+                }
+            }
+            (
+                Self::Decimal256(Some(l), lprecision, lscale),
+                Self::Decimal256(Some(r), rprecision, rscale),
+            ) => {
+                if lprecision == rprecision && lscale == rscale {
+                    l.checked_sub(*r)?.checked_abs()?.to_usize()
+                } else {
+                    None
+                }
+            }
             _ => None,
         }
     }
@@ -4489,7 +4598,9 @@ mod tests {
     };
     use arrow::buffer::{Buffer, OffsetBuffer};
     use arrow::compute::{is_null, kernels};
-    use arrow::datatypes::{ArrowNumericType, Fields, Float64Type};
+    use arrow::datatypes::{
+        ArrowNumericType, Fields, Float64Type, DECIMAL256_MAX_PRECISION,
+    };
     use arrow::error::ArrowError;
     use arrow::util::pretty::pretty_format_columns;
     use chrono::NaiveDate;
@@ -5225,6 +5336,116 @@ mod tests {
         Ok(())
     }
 
+    #[test]
+    fn test_new_one_decimal128() {
+        assert_eq!(
+            ScalarValue::new_one(&DataType::Decimal128(5, 0)).unwrap(),
+            ScalarValue::Decimal128(Some(1), 5, 0)
+        );
+        assert_eq!(
+            ScalarValue::new_one(&DataType::Decimal128(5, 1)).unwrap(),
+            ScalarValue::Decimal128(Some(10), 5, 1)
+        );
+        assert_eq!(
+            ScalarValue::new_one(&DataType::Decimal128(5, 2)).unwrap(),
+            ScalarValue::Decimal128(Some(100), 5, 2)
+        );
+        // More precision
+        assert_eq!(
+            ScalarValue::new_one(&DataType::Decimal128(7, 2)).unwrap(),
+            ScalarValue::Decimal128(Some(100), 7, 2)
+        );
+        // No negative scale
+        assert!(ScalarValue::new_one(&DataType::Decimal128(5, -1)).is_err());
+        // Invalid combination
+        assert!(ScalarValue::new_one(&DataType::Decimal128(0, 2)).is_err());
+        assert!(ScalarValue::new_one(&DataType::Decimal128(5, 7)).is_err());
+    }
+
+    #[test]
+    fn test_new_one_decimal256() {
+        assert_eq!(
+            ScalarValue::new_one(&DataType::Decimal256(5, 0)).unwrap(),
+            ScalarValue::Decimal256(Some(1.into()), 5, 0)
+        );
+        assert_eq!(
+            ScalarValue::new_one(&DataType::Decimal256(5, 1)).unwrap(),
+            ScalarValue::Decimal256(Some(10.into()), 5, 1)
+        );
+        assert_eq!(
+            ScalarValue::new_one(&DataType::Decimal256(5, 2)).unwrap(),
+            ScalarValue::Decimal256(Some(100.into()), 5, 2)
+        );
+        // More precision
+        assert_eq!(
+            ScalarValue::new_one(&DataType::Decimal256(7, 2)).unwrap(),
+            ScalarValue::Decimal256(Some(100.into()), 7, 2)
+        );
+        // No negative scale
+        assert!(ScalarValue::new_one(&DataType::Decimal256(5, -1)).is_err());
+        // Invalid combination
+        assert!(ScalarValue::new_one(&DataType::Decimal256(0, 2)).is_err());
+        assert!(ScalarValue::new_one(&DataType::Decimal256(5, 7)).is_err());
+    }
+
+    #[test]
+    fn test_new_ten_decimal128() {
+        assert_eq!(
+            ScalarValue::new_ten(&DataType::Decimal128(5, 1)).unwrap(),
+            ScalarValue::Decimal128(Some(100), 5, 1)
+        );
+        assert_eq!(
+            ScalarValue::new_ten(&DataType::Decimal128(5, 2)).unwrap(),
+            ScalarValue::Decimal128(Some(1000), 5, 2)
+        );
+        // More precision
+        assert_eq!(
+            ScalarValue::new_ten(&DataType::Decimal128(7, 2)).unwrap(),
+            ScalarValue::Decimal128(Some(1000), 7, 2)
+        );
+        // No negative or zero scale
+        assert!(ScalarValue::new_ten(&DataType::Decimal128(5, 0)).is_err());
+        assert!(ScalarValue::new_ten(&DataType::Decimal128(5, -1)).is_err());
+        // Invalid combination
+        assert!(ScalarValue::new_ten(&DataType::Decimal128(0, 2)).is_err());
+        assert!(ScalarValue::new_ten(&DataType::Decimal128(5, 7)).is_err());
+    }
+
+    #[test]
+    fn test_new_ten_decimal256() {
+        assert_eq!(
+            ScalarValue::new_ten(&DataType::Decimal256(5, 1)).unwrap(),
+            ScalarValue::Decimal256(Some(100.into()), 5, 1)
+        );
+        assert_eq!(
+            ScalarValue::new_ten(&DataType::Decimal256(5, 2)).unwrap(),
+            ScalarValue::Decimal256(Some(1000.into()), 5, 2)
+        );
+        // More precision
+        assert_eq!(
+            ScalarValue::new_ten(&DataType::Decimal256(7, 2)).unwrap(),
+            ScalarValue::Decimal256(Some(1000.into()), 7, 2)
+        );
+        // No negative or zero scale
+        assert!(ScalarValue::new_ten(&DataType::Decimal256(5, 0)).is_err());
+        assert!(ScalarValue::new_ten(&DataType::Decimal256(5, -1)).is_err());
+        // Invalid combination
+        assert!(ScalarValue::new_ten(&DataType::Decimal256(0, 2)).is_err());
+        assert!(ScalarValue::new_ten(&DataType::Decimal256(5, 7)).is_err());
+    }
+
+    #[test]
+    fn test_new_negative_one_decimal128() {
+        assert_eq!(
+            ScalarValue::new_negative_one(&DataType::Decimal128(5, 
0)).unwrap(),
+            ScalarValue::Decimal128(Some(-1), 5, 0)
+        );
+        assert_eq!(
+            ScalarValue::new_negative_one(&DataType::Decimal128(5, 
2)).unwrap(),
+            ScalarValue::Decimal128(Some(-100), 5, 2)
+        );
+    }
+
     #[test]
     fn test_list_partial_cmp() {
         let a =
@@ -7275,6 +7496,26 @@ mod tests {
                 ScalarValue::Float64(Some(-9.9)),
                 5,
             ),
+            (
+                ScalarValue::Decimal128(Some(10), 1, 0),
+                ScalarValue::Decimal128(Some(5), 1, 0),
+                5,
+            ),
+            (
+                ScalarValue::Decimal128(Some(5), 1, 0),
+                ScalarValue::Decimal128(Some(10), 1, 0),
+                5,
+            ),
+            (
+                ScalarValue::Decimal256(Some(10.into()), 1, 0),
+                ScalarValue::Decimal256(Some(5.into()), 1, 0),
+                5,
+            ),
+            (
+                ScalarValue::Decimal256(Some(5.into()), 1, 0),
+                ScalarValue::Decimal256(Some(10.into()), 1, 0),
+                5,
+            ),
         ];
         for (lhs, rhs, expected) in cases.iter() {
             let distance = lhs.distance(rhs).unwrap();
@@ -7282,6 +7523,24 @@ mod tests {
         }
     }
 
+    #[test]
+    fn test_distance_none() {
+        let cases = [
+            (
+                ScalarValue::Decimal128(Some(i128::MAX), 
DECIMAL128_MAX_PRECISION, 0),
+                ScalarValue::Decimal128(Some(-i128::MAX), 
DECIMAL128_MAX_PRECISION, 0),
+            ),
+            (
+                ScalarValue::Decimal256(Some(i256::MAX), 
DECIMAL256_MAX_PRECISION, 0),
+                ScalarValue::Decimal256(Some(-i256::MAX), 
DECIMAL256_MAX_PRECISION, 0),
+            ),
+        ];
+        for (lhs, rhs) in cases.iter() {
+            let distance = lhs.distance(rhs);
+            assert!(distance.is_none(), "{lhs} vs {rhs}");
+        }
+    }
+
     #[test]
     fn test_scalar_distance_invalid() {
         let cases = [
@@ -7323,7 +7582,33 @@ mod tests {
             (ScalarValue::Date64(Some(0)), ScalarValue::Date64(Some(1))),
             (
                 ScalarValue::Decimal128(Some(123), 5, 5),
-                ScalarValue::Decimal128(Some(120), 5, 5),
+                ScalarValue::Decimal128(Some(120), 5, 3),
+            ),
+            (
+                ScalarValue::Decimal128(Some(123), 5, 5),
+                ScalarValue::Decimal128(Some(120), 3, 5),
+            ),
+            (
+                ScalarValue::Decimal256(Some(123.into()), 5, 5),
+                ScalarValue::Decimal256(Some(120.into()), 3, 5),
+            ),
+            // Distance 2 * 2^50 is larger than usize
+            (
+                ScalarValue::Decimal256(
+                    Some(i256::from_parts(0, 2_i64.pow(50).into())),
+                    1,
+                    0,
+                ),
+                ScalarValue::Decimal256(
+                    Some(i256::from_parts(0, (-(2_i64).pow(50)).into())),
+                    1,
+                    0,
+                ),
+            ),
+            // Distance overflow
+            (
+                ScalarValue::Decimal256(Some(i256::from_parts(0, i128::MAX)), 
1, 0),
+                ScalarValue::Decimal256(Some(i256::from_parts(0, -i128::MAX)), 
1, 0),
             ),
         ];
         for (lhs, rhs) in cases {
diff --git a/datafusion/optimizer/src/simplify_expressions/utils.rs 
b/datafusion/optimizer/src/simplify_expressions/utils.rs
index 4df0e125eb..2f7dadceba 100644
--- a/datafusion/optimizer/src/simplify_expressions/utils.rs
+++ b/datafusion/optimizer/src/simplify_expressions/utils.rs
@@ -17,6 +17,7 @@
 
 //! Utility functions for expression simplification
 
+use arrow::datatypes::i256;
 use datafusion_common::{internal_err, Result, ScalarValue};
 use datafusion_expr::{
     expr::{Between, BinaryExpr, InList},
@@ -150,6 +151,11 @@ pub fn is_zero(s: &Expr) -> bool {
         Expr::Literal(ScalarValue::Float32(Some(v)), _) if *v == 0. => true,
         Expr::Literal(ScalarValue::Float64(Some(v)), _) if *v == 0. => true,
         Expr::Literal(ScalarValue::Decimal128(Some(v), _p, _s), _) if *v == 0 
=> true,
+        Expr::Literal(ScalarValue::Decimal256(Some(v), _p, _s), _)
+            if *v == i256::ZERO =>
+        {
+            true
+        }
         _ => false,
     }
 }
@@ -173,6 +179,13 @@ pub fn is_one(s: &Expr) -> bool {
                     .map(|x| x == v)
                     .unwrap_or_default()
         }
+        Expr::Literal(ScalarValue::Decimal256(Some(v), _p, s), _) => {
+            *s >= 0
+                && match i256::from(10).checked_pow(*s as u32) {
+                    Some(res) => res == *v,
+                    None => false,
+                }
+        }
         _ => false,
     }
 }
@@ -365,3 +378,78 @@ pub fn distribute_negation(expr: Expr) -> Expr {
         _ => Expr::Negative(Box::new(expr)),
     }
 }
+
+#[cfg(test)]
+mod tests {
+    use super::{is_one, is_zero};
+    use arrow::datatypes::i256;
+    use datafusion_common::ScalarValue;
+    use datafusion_expr::lit;
+
+    #[test]
+    fn test_is_zero() {
+        assert!(is_zero(&lit(ScalarValue::Int8(Some(0)))));
+        assert!(is_zero(&lit(ScalarValue::Float32(Some(0.0)))));
+        assert!(is_zero(&lit(ScalarValue::Decimal128(
+            Some(i128::from(0)),
+            9,
+            0
+        ))));
+        assert!(is_zero(&lit(ScalarValue::Decimal128(
+            Some(i128::from(0)),
+            9,
+            5
+        ))));
+        assert!(is_zero(&lit(ScalarValue::Decimal256(
+            Some(i256::ZERO),
+            9,
+            0
+        ))));
+        assert!(is_zero(&lit(ScalarValue::Decimal256(
+            Some(i256::ZERO),
+            9,
+            5
+        ))));
+    }
+
+    #[test]
+    fn test_is_one() {
+        assert!(is_one(&lit(ScalarValue::Int8(Some(1)))));
+        assert!(is_one(&lit(ScalarValue::Float32(Some(1.0)))));
+        assert!(is_one(&lit(ScalarValue::Decimal128(
+            Some(i128::from(1)),
+            9,
+            0
+        ))));
+        assert!(is_one(&lit(ScalarValue::Decimal128(
+            Some(i128::from(10)),
+            9,
+            1
+        ))));
+        assert!(is_one(&lit(ScalarValue::Decimal128(
+            Some(i128::from(100)),
+            9,
+            2
+        ))));
+        assert!(is_one(&lit(ScalarValue::Decimal256(
+            Some(i256::from(1)),
+            9,
+            0
+        ))));
+        assert!(is_one(&lit(ScalarValue::Decimal256(
+            Some(i256::from(10)),
+            9,
+            1
+        ))));
+        assert!(is_one(&lit(ScalarValue::Decimal256(
+            Some(i256::from(100)),
+            9,
+            2
+        ))));
+        assert!(!is_one(&lit(ScalarValue::Decimal256(
+            Some(i256::from(100)),
+            9,
+            -1
+        ))));
+    }
+}


---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscr...@datafusion.apache.org
For additional commands, e-mail: commits-h...@datafusion.apache.org

Reply via email to