martin-g commented on code in PR #2835:
URL: https://github.com/apache/datafusion-comet/pull/2835#discussion_r2601464262


##########
native/spark-expr/src/conversion_funcs/cast.rs:
##########
@@ -1058,6 +1055,351 @@ fn cast_array(
     Ok(spark_cast_postprocess(cast_result?, from_type, to_type))
 }
 
+fn cast_string_to_decimal(
+    array: &ArrayRef,
+    to_type: &DataType,
+    precision: &u8,
+    scale: &i8,
+    eval_mode: EvalMode,
+) -> SparkResult<ArrayRef> {
+    match to_type {
+        DataType::Decimal128(_, _) => {
+            cast_string_to_decimal128_impl(array, eval_mode, *precision, 
*scale)
+        }
+        DataType::Decimal256(_, _) => {
+            cast_string_to_decimal256_impl(array, eval_mode, *precision, 
*scale)
+        }
+        _ => Err(SparkError::Internal(format!(
+            "Unexpected type in cast_string_to_decimal: {:?}",
+            to_type
+        ))),
+    }
+}
+
+fn cast_string_to_decimal128_impl(
+    array: &ArrayRef,
+    eval_mode: EvalMode,
+    precision: u8,
+    scale: i8,
+) -> SparkResult<ArrayRef> {
+    let string_array = array
+        .as_any()
+        .downcast_ref::<StringArray>()
+        .ok_or_else(|| SparkError::Internal("Expected string 
array".to_string()))?;
+
+    let mut decimal_builder = 
Decimal128Builder::with_capacity(string_array.len());
+
+    for i in 0..string_array.len() {
+        if string_array.is_null(i) {
+            decimal_builder.append_null();
+        } else {
+            let str_value = string_array.value(i).trim();
+            match parse_string_to_decimal(str_value, precision, scale) {
+                Ok(Some(decimal_value)) => {
+                    decimal_builder.append_value(decimal_value);
+                }
+                Ok(None) => {
+                    if eval_mode == EvalMode::Ansi {
+                        return Err(invalid_value(
+                            str_value,
+                            "STRING",
+                            &format!("DECIMAL({},{})", precision, scale),
+                        ));
+                    }
+                    decimal_builder.append_null();
+                }
+                Err(e) => {
+                    if eval_mode == EvalMode::Ansi {
+                        return Err(e);
+                    }
+                    decimal_builder.append_null();
+                }
+            }
+        }
+    }
+
+    Ok(Arc::new(
+        decimal_builder
+            .with_precision_and_scale(precision, scale)?
+            .finish(),
+    ))
+}
+
+fn cast_string_to_decimal256_impl(
+    array: &ArrayRef,
+    eval_mode: EvalMode,
+    precision: u8,
+    scale: i8,
+) -> SparkResult<ArrayRef> {
+    let string_array = array
+        .as_any()
+        .downcast_ref::<StringArray>()
+        .ok_or_else(|| SparkError::Internal("Expected string 
array".to_string()))?;
+
+    let mut decimal_builder = 
PrimitiveBuilder::<Decimal256Type>::with_capacity(string_array.len());
+
+    for i in 0..string_array.len() {
+        if string_array.is_null(i) {
+            decimal_builder.append_null();
+        } else {
+            let str_value = string_array.value(i).trim();
+            match parse_string_to_decimal(str_value, precision, scale) {
+                Ok(Some(decimal_value)) => {
+                    // Convert i128 to i256
+                    let i256_value = i256::from_i128(decimal_value);
+                    decimal_builder.append_value(i256_value);
+                }
+                Ok(None) => {
+                    if eval_mode == EvalMode::Ansi {
+                        return Err(invalid_value(
+                            str_value,
+                            "STRING",
+                            &format!("DECIMAL({},{})", precision, scale),
+                        ));
+                    }
+                    decimal_builder.append_null();
+                }
+                Err(e) => {
+                    if eval_mode == EvalMode::Ansi {
+                        return Err(e);
+                    }
+                    decimal_builder.append_null();
+                }
+            }
+        }
+    }
+
+    Ok(Arc::new(
+        decimal_builder
+            .with_precision_and_scale(precision, scale)?
+            .finish(),
+    ))
+}
+
+/// Parse a string to decimal following Spark's behavior
+/// Returns Ok(Some(value)) if successful, Ok(None) if null, Err if invalid in 
ANSI mode
+fn parse_string_to_decimal(s: &str, precision: u8, scale: i8) -> 
SparkResult<Option<i128>> {
+    if s.is_empty() {
+        return Ok(None);
+    }
+
+    // Handle special values (inf, nan, etc.)
+    let s_lower = s.to_lowercase();

Review Comment:
   You can use `s.eq_ignore_ascii_case("...")` for the checks below.
   This will avoid the allocation done by `.to_lowercase()`.



##########
native/spark-expr/src/conversion_funcs/cast.rs:
##########
@@ -1058,6 +1055,351 @@ fn cast_array(
     Ok(spark_cast_postprocess(cast_result?, from_type, to_type))
 }
 
+fn cast_string_to_decimal(
+    array: &ArrayRef,
+    to_type: &DataType,
+    precision: &u8,
+    scale: &i8,
+    eval_mode: EvalMode,
+) -> SparkResult<ArrayRef> {
+    match to_type {
+        DataType::Decimal128(_, _) => {
+            cast_string_to_decimal128_impl(array, eval_mode, *precision, 
*scale)
+        }
+        DataType::Decimal256(_, _) => {
+            cast_string_to_decimal256_impl(array, eval_mode, *precision, 
*scale)
+        }
+        _ => Err(SparkError::Internal(format!(
+            "Unexpected type in cast_string_to_decimal: {:?}",
+            to_type
+        ))),
+    }
+}
+
+fn cast_string_to_decimal128_impl(
+    array: &ArrayRef,
+    eval_mode: EvalMode,
+    precision: u8,
+    scale: i8,
+) -> SparkResult<ArrayRef> {
+    let string_array = array
+        .as_any()
+        .downcast_ref::<StringArray>()
+        .ok_or_else(|| SparkError::Internal("Expected string 
array".to_string()))?;
+
+    let mut decimal_builder = 
Decimal128Builder::with_capacity(string_array.len());
+
+    for i in 0..string_array.len() {
+        if string_array.is_null(i) {
+            decimal_builder.append_null();
+        } else {
+            let str_value = string_array.value(i).trim();
+            match parse_string_to_decimal(str_value, precision, scale) {
+                Ok(Some(decimal_value)) => {
+                    decimal_builder.append_value(decimal_value);
+                }
+                Ok(None) => {
+                    if eval_mode == EvalMode::Ansi {
+                        return Err(invalid_value(
+                            str_value,
+                            "STRING",
+                            &format!("DECIMAL({},{})", precision, scale),
+                        ));
+                    }
+                    decimal_builder.append_null();
+                }
+                Err(e) => {
+                    if eval_mode == EvalMode::Ansi {
+                        return Err(e);
+                    }
+                    decimal_builder.append_null();
+                }
+            }
+        }
+    }
+
+    Ok(Arc::new(
+        decimal_builder
+            .with_precision_and_scale(precision, scale)?
+            .finish(),
+    ))
+}
+
+fn cast_string_to_decimal256_impl(
+    array: &ArrayRef,
+    eval_mode: EvalMode,
+    precision: u8,
+    scale: i8,
+) -> SparkResult<ArrayRef> {
+    let string_array = array
+        .as_any()
+        .downcast_ref::<StringArray>()
+        .ok_or_else(|| SparkError::Internal("Expected string 
array".to_string()))?;
+
+    let mut decimal_builder = 
PrimitiveBuilder::<Decimal256Type>::with_capacity(string_array.len());
+
+    for i in 0..string_array.len() {
+        if string_array.is_null(i) {
+            decimal_builder.append_null();
+        } else {
+            let str_value = string_array.value(i).trim();
+            match parse_string_to_decimal(str_value, precision, scale) {
+                Ok(Some(decimal_value)) => {
+                    // Convert i128 to i256
+                    let i256_value = i256::from_i128(decimal_value);
+                    decimal_builder.append_value(i256_value);
+                }
+                Ok(None) => {
+                    if eval_mode == EvalMode::Ansi {
+                        return Err(invalid_value(
+                            str_value,
+                            "STRING",
+                            &format!("DECIMAL({},{})", precision, scale),
+                        ));
+                    }
+                    decimal_builder.append_null();
+                }
+                Err(e) => {
+                    if eval_mode == EvalMode::Ansi {
+                        return Err(e);
+                    }
+                    decimal_builder.append_null();
+                }
+            }
+        }
+    }
+
+    Ok(Arc::new(
+        decimal_builder
+            .with_precision_and_scale(precision, scale)?
+            .finish(),
+    ))
+}
+
+/// Parse a string to decimal following Spark's behavior
+/// Returns Ok(Some(value)) if successful, Ok(None) if null, Err if invalid in 
ANSI mode
+fn parse_string_to_decimal(s: &str, precision: u8, scale: i8) -> 
SparkResult<Option<i128>> {
+    if s.is_empty() {
+        return Ok(None);
+    }
+
+    // Handle special values (inf, nan, etc.)
+    let s_lower = s.to_lowercase();
+    if s_lower == "inf"
+        || s_lower == "+inf"
+        || s_lower == "infinity"
+        || s_lower == "+infinity"
+        || s_lower == "-inf"
+        || s_lower == "-infinity"
+        || s_lower == "nan"
+    {
+        return Ok(None);
+    }
+
+    // Parse the string as a decimal number
+    // Note: We do NOT strip 'D' or 'F' suffixes - let rust's parsing fail 
naturally for invalid input
+    match parse_decimal_str(s) {
+        Ok((mantissa, exponent)) => {
+            // Convert to target scale
+            let target_scale = scale as i32;
+            let scale_adjustment = target_scale - exponent;
+
+            let scaled_value = if scale_adjustment >= 0 {
+                // Need to multiply (increase scale)
+                mantissa.checked_mul(10_i128.pow(scale_adjustment as u32))
+            } else {
+                // Need to divide (decrease scale) - use rounding half up
+                let divisor = 10_i128.pow((-scale_adjustment) as u32);
+                let quotient = mantissa / divisor;
+                let remainder = mantissa % divisor;
+
+                // Round half up: if abs(remainder) >= divisor/2, round away 
from zero
+                let half_divisor = divisor / 2;
+                let rounded = if remainder.abs() >= half_divisor {
+                    if mantissa >= 0 {
+                        quotient + 1
+                    } else {
+                        quotient - 1
+                    }
+                } else {
+                    quotient
+                };
+                Some(rounded)
+            };
+
+            match scaled_value {
+                Some(value) => {
+                    // Check if it fits target precision
+                    if is_validate_decimal_precision(value, precision) {
+                        Ok(Some(value))
+                    } else {
+                        // Overflow
+                        Ok(None)
+                    }
+                }
+                None => {
+                    // Overflow during scaling
+                    Ok(None)
+                }
+            }
+        }
+        Err(_) => Ok(None),
+    }
+}
+
+/// Parse a decimal string into mantissa and scale
+/// e.g., "123.45" -> (12345, 2), "-0.001" -> (-1, 3)
+fn parse_decimal_str(s: &str) -> Result<(i128, i32), String> {
+    let s = s.trim();
+    if s.is_empty() {
+        return Err("Empty string".to_string());
+    }
+
+    // Check if input is scientific notation (e.g., "1.23E-5", "1e10")
+    let (mantissa_str, exponent) = if let Some(e_pos) = s.find(|c| ['e', 
'E'].contains(&c)) {
+        let mantissa_part = &s[..e_pos];
+        let exponent_part = &s[e_pos + 1..];
+
+        // Parse exponent part
+        let exp: i32 = exponent_part
+            .parse()
+            .map_err(|_| "Invalid exponent".to_string())?;
+
+        (mantissa_part, exp)
+    } else {
+        (s, 0)
+    };
+
+    let negative = mantissa_str.starts_with('-');
+    let mantissa_str = if negative || mantissa_str.starts_with('+') {
+        &mantissa_str[1..]
+    } else {
+        mantissa_str
+    };
+
+    let split_by_dot: Vec<&str> = mantissa_str.split('.').collect();
+
+    if split_by_dot.len() > 2 {
+        return Err("Multiple decimal points".to_string());
+    }
+
+    let integral_part = split_by_dot[0];
+    let fractional_part = if split_by_dot.len() == 2 {
+        split_by_dot[1]
+    } else {
+        ""
+    };
+
+    // Parse integral part
+    let integral_value: i128 = if integral_part.is_empty() {
+        0

Review Comment:
   Hm, I am not sure about this.
   With an empty integral part treated as 0, parsing "e5" yields 0, but it should return an error.
   Defaulting to 0 is only right when a fractional part follows, as in ".0".



-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

To unsubscribe, e-mail: [email protected]

For queries about this service, please contact Infrastructure at:
[email protected]


---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]

Reply via email to