This is an automated email from the ASF dual-hosted git repository. agrove pushed a commit to branch main in repository https://gitbox.apache.org/repos/asf/datafusion-comet.git
The following commit(s) were added to refs/heads/main by this push: new 1f3af8f7f fix: Fix overflow handling when casting float to decimal (#1914) 1f3af8f7f is described below commit 1f3af8f7f300e2253314f51732a0c2d87dabb558 Author: Leung Ming <165622843+leung-m...@users.noreply.github.com> AuthorDate: Sat Jun 28 04:42:07 2025 +0800 fix: Fix overflow handling when casting float to decimal (#1914) --- native/spark-expr/src/conversion_funcs/cast.rs | 83 ++++++++++++++++---------- 1 file changed, 51 insertions(+), 32 deletions(-) diff --git a/native/spark-expr/src/conversion_funcs/cast.rs b/native/spark-expr/src/conversion_funcs/cast.rs index 8c00f545a..4f724bae5 100644 --- a/native/spark-expr/src/conversion_funcs/cast.rs +++ b/native/spark-expr/src/conversion_funcs/cast.rs @@ -31,8 +31,8 @@ use arrow::{ }, compute::{cast_with_options, take, unary, CastOptions}, datatypes::{ - ArrowPrimitiveType, Decimal128Type, DecimalType, Float32Type, Float64Type, Int64Type, - TimestampMicrosecondType, + is_validate_decimal_precision, ArrowPrimitiveType, Decimal128Type, Float32Type, + Float64Type, Int64Type, TimestampMicrosecondType, }, error::ArrowError, record_batch::RecordBatch, @@ -1287,38 +1287,25 @@ where for i in 0..input.len() { if input.is_null(i) { cast_array.append_null(); - } else { - let input_value = input.value(i).as_(); - let value = (input_value * mul).round().to_i128(); - - match value { - Some(v) => { - if Decimal128Type::validate_decimal_precision(v, precision).is_err() { - if eval_mode == EvalMode::Ansi { - return Err(SparkError::NumericValueOutOfRange { - value: input_value.to_string(), - precision, - scale, - }); - } else { - cast_array.append_null(); - } - } - cast_array.append_value(v); - } - None => { - if eval_mode == EvalMode::Ansi { - return Err(SparkError::NumericValueOutOfRange { - value: input_value.to_string(), - precision, - scale, - }); - } else { - cast_array.append_null(); - } - } + continue; + } + + let input_value = input.value(i).as_(); + if let Some(v) = (input_value * mul).round().to_i128() { + if is_validate_decimal_precision(v, precision) { + cast_array.append_value(v); + continue; } + }; + + if eval_mode == EvalMode::Ansi { + return Err(SparkError::NumericValueOutOfRange { + value: input_value.to_string(), + precision, + scale, + }); } + cast_array.append_null(); } let res = Arc::new( @@ -2203,6 +2190,7 @@ mod tests { use arrow::array::StringArray; use arrow::datatypes::TimestampMicrosecondType; use arrow::datatypes::{Field, Fields, TimeUnit}; + use core::f64; use std::str::FromStr; use super::*; @@ -2671,4 +2659,35 @@ mod tests { unreachable!() } } + + #[test] + fn test_cast_float_to_decimal() { + let a: ArrayRef = Arc::new(Float64Array::from(vec![ + Some(42.), + Some(0.5153125), + Some(-42.4242415), + Some(42e-314), + Some(0.), + Some(-4242.424242), + Some(f64::INFINITY), + Some(f64::NEG_INFINITY), + Some(f64::NAN), + None, + ])); + let b = + cast_floating_point_to_decimal128::<Float64Type>(&a, 8, 6, EvalMode::Legacy).unwrap(); + assert_eq!(b.len(), a.len()); + let casted = b.as_primitive::<Decimal128Type>(); + assert_eq!(casted.value(0), 42000000); + // https://github.com/apache/datafusion-comet/issues/1371 + // assert_eq!(casted.value(1), 515313); + assert_eq!(casted.value(2), -42424242); + assert_eq!(casted.value(3), 0); + assert_eq!(casted.value(4), 0); + assert!(casted.is_null(5)); + assert!(casted.is_null(6)); + assert!(casted.is_null(7)); + assert!(casted.is_null(8)); + assert!(casted.is_null(9)); + } } --------------------------------------------------------------------- To unsubscribe, e-mail: commits-unsubscr...@datafusion.apache.org For additional commands, e-mail: commits-h...@datafusion.apache.org