This is an automated email from the ASF dual-hosted git repository.

agrove pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/datafusion-comet.git


The following commit(s) were added to refs/heads/main by this push:
     new 1f3af8f7f fix: Fix overflow handling when casting float to decimal (#1914)
1f3af8f7f is described below

commit 1f3af8f7f300e2253314f51732a0c2d87dabb558
Author: Leung Ming <165622843+leung-m...@users.noreply.github.com>
AuthorDate: Sat Jun 28 04:42:07 2025 +0800

    fix: Fix overflow handling when casting float to decimal (#1914)
---
 native/spark-expr/src/conversion_funcs/cast.rs | 83 ++++++++++++++++----------
 1 file changed, 51 insertions(+), 32 deletions(-)

diff --git a/native/spark-expr/src/conversion_funcs/cast.rs b/native/spark-expr/src/conversion_funcs/cast.rs
index 8c00f545a..4f724bae5 100644
--- a/native/spark-expr/src/conversion_funcs/cast.rs
+++ b/native/spark-expr/src/conversion_funcs/cast.rs
@@ -31,8 +31,8 @@ use arrow::{
     },
     compute::{cast_with_options, take, unary, CastOptions},
     datatypes::{
-        ArrowPrimitiveType, Decimal128Type, DecimalType, Float32Type, Float64Type, Int64Type,
-        TimestampMicrosecondType,
+        is_validate_decimal_precision, ArrowPrimitiveType, Decimal128Type, Float32Type,
+        Float64Type, Int64Type, TimestampMicrosecondType,
     },
     error::ArrowError,
     record_batch::RecordBatch,
@@ -1287,38 +1287,25 @@ where
     for i in 0..input.len() {
         if input.is_null(i) {
             cast_array.append_null();
-        } else {
-            let input_value = input.value(i).as_();
-            let value = (input_value * mul).round().to_i128();
-
-            match value {
-                Some(v) => {
-                    if Decimal128Type::validate_decimal_precision(v, precision).is_err() {
-                        if eval_mode == EvalMode::Ansi {
-                            return Err(SparkError::NumericValueOutOfRange {
-                                value: input_value.to_string(),
-                                precision,
-                                scale,
-                            });
-                        } else {
-                            cast_array.append_null();
-                        }
-                    }
-                    cast_array.append_value(v);
-                }
-                None => {
-                    if eval_mode == EvalMode::Ansi {
-                        return Err(SparkError::NumericValueOutOfRange {
-                            value: input_value.to_string(),
-                            precision,
-                            scale,
-                        });
-                    } else {
-                        cast_array.append_null();
-                    }
-                }
+            continue;
+        }
+
+        let input_value = input.value(i).as_();
+        if let Some(v) = (input_value * mul).round().to_i128() {
+            if is_validate_decimal_precision(v, precision) {
+                cast_array.append_value(v);
+                continue;
             }
+        };
+
+        if eval_mode == EvalMode::Ansi {
+            return Err(SparkError::NumericValueOutOfRange {
+                value: input_value.to_string(),
+                precision,
+                scale,
+            });
         }
+        cast_array.append_null();
     }
 
     let res = Arc::new(
@@ -2203,6 +2190,7 @@ mod tests {
     use arrow::array::StringArray;
     use arrow::datatypes::TimestampMicrosecondType;
     use arrow::datatypes::{Field, Fields, TimeUnit};
+    use core::f64;
     use std::str::FromStr;
 
     use super::*;
@@ -2671,4 +2659,35 @@ mod tests {
             unreachable!()
         }
     }
+
+    #[test]
+    fn test_cast_float_to_decimal() {
+        let a: ArrayRef = Arc::new(Float64Array::from(vec![
+            Some(42.),
+            Some(0.5153125),
+            Some(-42.4242415),
+            Some(42e-314),
+            Some(0.),
+            Some(-4242.424242),
+            Some(f64::INFINITY),
+            Some(f64::NEG_INFINITY),
+            Some(f64::NAN),
+            None,
+        ]));
+        let b =
+            cast_floating_point_to_decimal128::<Float64Type>(&a, 8, 6, EvalMode::Legacy).unwrap();
+        assert_eq!(b.len(), a.len());
+        let casted = b.as_primitive::<Decimal128Type>();
+        assert_eq!(casted.value(0), 42000000);
+        // https://github.com/apache/datafusion-comet/issues/1371
+        // assert_eq!(casted.value(1), 515313);
+        assert_eq!(casted.value(2), -42424242);
+        assert_eq!(casted.value(3), 0);
+        assert_eq!(casted.value(4), 0);
+        assert!(casted.is_null(5));
+        assert!(casted.is_null(6));
+        assert!(casted.is_null(7));
+        assert!(casted.is_null(8));
+        assert!(casted.is_null(9));
+    }
 }


---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscr...@datafusion.apache.org
For additional commands, e-mail: commits-h...@datafusion.apache.org

Reply via email to