andygrove commented on code in PR #2925:
URL: https://github.com/apache/datafusion-comet/pull/2925#discussion_r2625345509


##########
native/spark-expr/src/conversion_funcs/cast.rs:
##########
@@ -1976,6 +1975,363 @@ fn do_cast_string_to_int<
     Ok(Some(result))
 }
 
+fn cast_string_to_decimal(
+    array: &ArrayRef,
+    to_type: &DataType,
+    precision: &u8,
+    scale: &i8,
+    eval_mode: EvalMode,
+) -> SparkResult<ArrayRef> {
+    match to_type {
+        DataType::Decimal128(_, _) => {
+            cast_string_to_decimal128_impl(array, eval_mode, *precision, 
*scale)
+        }
+        DataType::Decimal256(_, _) => {
+            cast_string_to_decimal256_impl(array, eval_mode, *precision, 
*scale)
+        }
+        _ => Err(SparkError::Internal(format!(
+            "Unexpected type in cast_string_to_decimal: {:?}",
+            to_type
+        ))),
+    }
+}
+
+fn cast_string_to_decimal128_impl(
+    array: &ArrayRef,
+    eval_mode: EvalMode,
+    precision: u8,
+    scale: i8,
+) -> SparkResult<ArrayRef> {
+    let string_array = array
+        .as_any()
+        .downcast_ref::<StringArray>()
+        .ok_or_else(|| SparkError::Internal("Expected string 
array".to_string()))?;
+
+    let mut decimal_builder = 
Decimal128Builder::with_capacity(string_array.len());
+
+    for i in 0..string_array.len() {
+        if string_array.is_null(i) {
+            decimal_builder.append_null();
+        } else {
+            let str_value = string_array.value(i).trim();
+            match parse_string_to_decimal(str_value, precision, scale) {
+                Ok(Some(decimal_value)) => {
+                    decimal_builder.append_value(decimal_value);
+                }
+                Ok(None) => {
+                    if eval_mode == EvalMode::Ansi {
+                        return Err(invalid_value(
+                            string_array.value(i),
+                            "STRING",
+                            &format!("DECIMAL({},{})", precision, scale),
+                        ));
+                    }
+                    decimal_builder.append_null();
+                }
+                Err(e) => {
+                    if eval_mode == EvalMode::Ansi {
+                        return Err(e);
+                    }
+                    decimal_builder.append_null();
+                }
+            }
+        }
+    }
+
+    Ok(Arc::new(
+        decimal_builder
+            .with_precision_and_scale(precision, scale)?
+            .finish(),
+    ))
+}
+
+fn cast_string_to_decimal256_impl(
+    array: &ArrayRef,
+    eval_mode: EvalMode,
+    precision: u8,
+    scale: i8,
+) -> SparkResult<ArrayRef> {
+    let string_array = array
+        .as_any()
+        .downcast_ref::<StringArray>()
+        .ok_or_else(|| SparkError::Internal("Expected string 
array".to_string()))?;
+
+    let mut decimal_builder = 
PrimitiveBuilder::<Decimal256Type>::with_capacity(string_array.len());
+
+    for i in 0..string_array.len() {
+        if string_array.is_null(i) {
+            decimal_builder.append_null();
+        } else {
+            let str_value = string_array.value(i).trim();
+            match parse_string_to_decimal(str_value, precision, scale) {
+                Ok(Some(decimal_value)) => {
+                    // Convert i128 to i256
+                    let i256_value = i256::from_i128(decimal_value);
+                    decimal_builder.append_value(i256_value);
+                }
+                Ok(None) => {
+                    if eval_mode == EvalMode::Ansi {
+                        return Err(invalid_value(
+                            str_value,
+                            "STRING",
+                            &format!("DECIMAL({},{})", precision, scale),
+                        ));
+                    }
+                    decimal_builder.append_null();
+                }
+                Err(e) => {
+                    if eval_mode == EvalMode::Ansi {
+                        return Err(e);
+                    }
+                    decimal_builder.append_null();
+                }
+            }
+        }
+    }
+
+    Ok(Arc::new(
+        decimal_builder
+            .with_precision_and_scale(precision, scale)?
+            .finish(),
+    ))
+}
+
/// Checks whether `s` has the shape of a number string that Java's `BigDecimal`
/// would accept.
///
/// Accepted grammar (ASCII only):
/// `[+-]? mantissa ([eE] [+-]? digit+)?`, where the mantissa consists of digits
/// and at most one `.`, and contains at least one digit.
/// Examples: `"12"`, `"-1.5"`, `"+.5"`, `"5."`, `"1.e5"`, `"1E-5"`.
/// Rejected: `""`, `"+"`, `"++1"`, `"."`, `"1.2.3"`, `"e5"`, `"1e"`, `"1e5.2"`.
fn is_valid_decimal_format(s: &str) -> bool {
    /// Drops a single leading `+` or `-`, if present.
    fn strip_sign(part: &[u8]) -> &[u8] {
        match part.first() {
            Some(b'+') | Some(b'-') => &part[1..],
            _ => part,
        }
    }

    let bytes = s.as_bytes();

    // Everything before the first 'e'/'E' is the mantissa; everything after it
    // (if an 'e' exists) is the exponent.
    let (mantissa, exponent) = match bytes.iter().position(|b| b.eq_ignore_ascii_case(&b'e')) {
        Some(pos) => (&bytes[..pos], Some(&bytes[pos + 1..])),
        None => (bytes, None),
    };

    // Mantissa: optional sign, then only digits and at most one decimal point,
    // with at least one digit somewhere.
    let mut digit_count = 0usize;
    let mut dot_count = 0usize;
    for &b in strip_sign(mantissa) {
        match b {
            b'0'..=b'9' => digit_count += 1,
            b'.' => dot_count += 1,
            _ => return false,
        }
    }
    if digit_count == 0 || dot_count > 1 {
        return false;
    }

    // Exponent: optional sign followed by at least one digit, and nothing else.
    match exponent {
        None => true,
        Some(exp) => {
            let exp_digits = strip_sign(exp);
            !exp_digits.is_empty() && exp_digits.iter().all(|b| b.is_ascii_digit())
        }
    }
}
+
+/// Parse a string to decimal following Spark's behavior
+fn parse_string_to_decimal(s: &str, precision: u8, scale: i8) -> 
SparkResult<Option<i128>> {
+    if s.is_empty() {
+        return Ok(None);
+    }
+    // Handle special values (inf, nan, etc.)
+    if s.eq_ignore_ascii_case("inf")
+        || s.eq_ignore_ascii_case("+inf")
+        || s.eq_ignore_ascii_case("infinity")
+        || s.eq_ignore_ascii_case("+infinity")
+        || s.eq_ignore_ascii_case("-inf")
+        || s.eq_ignore_ascii_case("-infinity")
+        || s.eq_ignore_ascii_case("nan")
+    {
+        return Ok(None);
+    }
+
+    if !is_valid_decimal_format(s) {

Review Comment:
   Iterating over the string twice (once for validation and once for parsing) 
seems expensive



-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

To unsubscribe, e-mail: [email protected]

For queries about this service, please contact Infrastructure at:
[email protected]


---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]

Reply via email to