andygrove commented on code in PR #3017:
URL: https://github.com/apache/datafusion-comet/pull/3017#discussion_r2665950351
##########
native/spark-expr/src/conversion_funcs/cast.rs:
##########
@@ -1954,100 +1964,247 @@ fn cast_string_to_int_with_range_check(
}
}
+// Returns (start, end) indices after trimming whitespace
+fn trim_whitespace(bytes: &[u8]) -> (usize, usize) {
+ let mut start = 0;
+ let mut end = bytes.len();
+
+ while start < end && bytes[start].is_ascii_whitespace() {
+ start += 1;
+ }
+ while end > start && bytes[end - 1].is_ascii_whitespace() {
+ end -= 1;
+ }
+
+ (start, end)
+}
+
+// Parses sign and returns (is_negative, start_idx after sign)
+// Returns None if invalid (e.g., just "+" or "-")
+fn parse_sign(trimmed_bytes: &[u8]) -> Option<(bool, usize)> {
+ let len = trimmed_bytes.len();
+ if len == 0 {
+ return None;
+ }
+
+ let first_char = trimmed_bytes[0];
+ let negative = first_char == b'-';
+
+ if negative || first_char == b'+' {
+ if len == 1 {
+ return None;
+ }
+ Some((negative, 1))
+ } else {
+ Some((false, 0))
+ }
+}
+
/// Equivalent to
/// - org.apache.spark.unsafe.types.UTF8String.toInt(IntWrapper intWrapper, boolean allowDecimal)
/// - org.apache.spark.unsafe.types.UTF8String.toLong(LongWrapper longWrapper, boolean allowDecimal)
-fn do_cast_string_to_int<
- T: Num + PartialOrd + Integer + CheckedSub + CheckedNeg + From<i32> + Copy,
->(
+fn do_parse_string_to_int_legacy<T: Integer + CheckedSub + CheckedNeg + From<u8> + Copy>(
str: &str,
- eval_mode: EvalMode,
- type_name: &str,
min_value: T,
) -> SparkResult<Option<T>> {
- let trimmed_str = str.trim();
- if trimmed_str.is_empty() {
- return none_or_err(eval_mode, type_name, str);
+ let bytes = str.as_bytes();
+ let (start, end) = trim_whitespace(bytes);
+
+ if start == end {
+ return Ok(None);
}
- let len = trimmed_str.len();
+ let trimmed_bytes = &bytes[start..end];
+
+ let (negative, idx) = match parse_sign(trimmed_bytes) {
+ Some(result) => result,
+ None => return Ok(None),
+ };
+
let mut result: T = T::zero();
- let mut negative = false;
- let radix = T::from(10);
+
+ let radix = T::from(10_u8);
let stop_value = min_value / radix;
let mut parse_sign_and_digits = true;
- for (i, ch) in trimmed_str.char_indices() {
+ for &ch in &trimmed_bytes[idx..] {
if parse_sign_and_digits {
- if i == 0 {
- negative = ch == '-';
- let positive = ch == '+';
- if negative || positive {
- if i + 1 == len {
- // input string is just "+" or "-"
- return none_or_err(eval_mode, type_name, str);
- }
- // consume this char
- continue;
- }
+ if ch == b'.' {
+ // truncate decimal in legacy mode
+ parse_sign_and_digits = false;
+ continue;
}
- if ch == '.' {
- if eval_mode == EvalMode::Legacy {
- // truncate decimal in legacy mode
- parse_sign_and_digits = false;
- continue;
- } else {
- return none_or_err(eval_mode, type_name, str);
- }
+ if !ch.is_ascii_digit() {
+ return Ok(None);
}
- let digit = if ch.is_ascii_digit() {
- (ch as u32) - ('0' as u32)
- } else {
- return none_or_err(eval_mode, type_name, str);
- };
+ let digit: T = T::from(ch - b'0');
- // We are going to process the new digit and accumulate the result. However, before
- // doing this, if the result is already smaller than the
- // stopValue(Integer.MIN_VALUE / radix), then result * 10 will definitely be
- // smaller than minValue, and we can stop
if result < stop_value {
- return none_or_err(eval_mode, type_name, str);
+ return Ok(None);
}
-
- // Since the previous result is greater than or equal to stopValue(Integer.MIN_VALUE /
- // radix), we can just use `result > 0` to check overflow. If result
- // overflows, we should stop
let v = result * radix;
- let digit = (digit as i32).into();
match v.checked_sub(&digit) {
Some(x) if x <= T::zero() => result = x,
_ => {
- return none_or_err(eval_mode, type_name, str);
+ return Ok(None);
}
}
} else {
- // make sure fractional digits are valid digits but ignore them
+ // in legacy mode we still process chars after the dot and make sure the chars are digits
if !ch.is_ascii_digit() {
- return none_or_err(eval_mode, type_name, str);
+ return Ok(None);
+ }
+ }
+ }
+
+ if !negative {
+ if let Some(neg) = result.checked_neg() {
+ if neg < T::zero() {
+ return Ok(None);
+ }
+ result = neg;
+ } else {
+ return Ok(None);
+ }
+ }
+
+ Ok(Some(result))
+}
+
+fn do_parse_string_to_int_ansi<T: Integer + CheckedSub + CheckedNeg + From<u8> + Copy>(
+ str: &str,
+ type_name: &str,
+ min_value: T,
+) -> SparkResult<Option<T>> {
+ let bytes = str.as_bytes();
+ let (start, end) = trim_whitespace(bytes);
+
+ if start == end {
+ return Err(invalid_value(str, "STRING", type_name));
+ }
+ let trimmed_bytes = &bytes[start..end];
+
+ let (negative, idx) = match parse_sign(trimmed_bytes) {
+ Some(result) => result,
+ None => return Err(invalid_value(str, "STRING", type_name)),
+ };
+
+ let mut result: T = T::zero();
+
+ let radix = T::from(10_u8);
+ let stop_value = min_value / radix;
+
+ for &ch in &trimmed_bytes[idx..] {
+ if ch == b'.' {
+ return Err(invalid_value(str, "STRING", type_name));
+ }
+
+ if !ch.is_ascii_digit() {
+ return Err(invalid_value(str, "STRING", type_name));
+ }
+
+ let digit: T = T::from(ch - b'0');
+
+ if result < stop_value {
+ return Err(invalid_value(str, "STRING", type_name));
+ }
+ let v = result * radix;
+ match v.checked_sub(&digit) {
+ Some(x) if x <= T::zero() => result = x,
+ _ => {
+ return Err(invalid_value(str, "STRING", type_name));
+ }
+ }
+ }
+
+ if !negative {
+ if let Some(neg) = result.checked_neg() {
+ if neg < T::zero() {
+ return Err(invalid_value(str, "STRING", type_name));
+ }
+ result = neg;
+ } else {
+ return Err(invalid_value(str, "STRING", type_name));
+ }
+ }
+
+ Ok(Some(result))
+}
+
+fn do_parse_string_to_int_try<T: Integer + CheckedSub + CheckedNeg + From<u8> + Copy>(
+ str: &str,
+ min_value: T,
+) -> SparkResult<Option<T>> {
+ let bytes = str.as_bytes();
+ let (start, end) = trim_whitespace(bytes);
+
+ if start == end {
+ return Ok(None);
+ }
+ let trimmed_bytes = &bytes[start..end];
+
+ let (negative, idx) = match parse_sign(trimmed_bytes) {
+ Some(result) => result,
+ None => return Ok(None),
+ };
+
+ let mut result: T = T::zero();
+
+ let radix = T::from(10_u8);
+ let stop_value = min_value / radix;
+
+ // we don't have to go beyond decimal point in try eval mode - early return NULL
+ for &ch in &trimmed_bytes[idx..] {
+ if ch == b'.' {
+ return Ok(None);
+ }
+
+ if !ch.is_ascii_digit() {
+ return Ok(None);
+ }
+
+ let digit: T = T::from(ch - b'0');
+
+ if result < stop_value {
+ return Ok(None);
+ }
+ let v = result * radix;
+ match v.checked_sub(&digit) {
+ Some(x) if x <= T::zero() => result = x,
+ _ => {
+ return Ok(None);
}
}
}
if !negative {
if let Some(neg) = result.checked_neg() {
if neg < T::zero() {
- return none_or_err(eval_mode, type_name, str);
+ return Ok(None);
}
result = neg;
} else {
- return none_or_err(eval_mode, type_name, str);
+ return Ok(None);
}
}
Ok(Some(result))
}
+fn do_cast_string_to_int<T: Integer + CheckedSub + CheckedNeg + From<u8> + Copy>(
+ str: &str,
+ eval_mode: EvalMode,
+ type_name: &str,
+ min_value: T,
+) -> SparkResult<Option<T>> {
+ match eval_mode {
Review Comment:
👍
--
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.
To unsubscribe, e-mail: [email protected]
For queries about this service, please contact Infrastructure at:
[email protected]
---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]