andygrove commented on code in PR #307:
URL: https://github.com/apache/datafusion-comet/pull/307#discussion_r1583889008


##########
core/src/execution/datafusion/expressions/cast.rs:
##########
@@ -142,6 +230,193 @@ impl Cast {
     }
 }
 
+/// Casts a string to an `i8`, passing `TINYINT` as the type name to the
+/// error helpers.
+///
+/// Delegates to `cast_string_to_int_with_range_check` with the `i8` bounds. An
+/// out-of-range value therefore yields an error in ANSI mode and `None` in other
+/// modes. Handling of unparseable input is delegated to `do_cast_string_to_int`.
+fn cast_string_to_i8(str: &str, eval_mode: EvalMode) -> CometResult<Option<i8>> {
+    Ok(cast_string_to_int_with_range_check(
+        str,
+        eval_mode,
+        "TINYINT",
+        i8::MIN as i32,
+        i8::MAX as i32,
+    )?
+    // The helper only returns `Some(v)` for `v` within the bounds passed above,
+    // so this narrowing cast cannot truncate.
+    .map(|v| v as i8))
+}
+
+/// Casts a string to an `i16`, passing `SMALLINT` as the type name to the
+/// error helpers.
+///
+/// Delegates to `cast_string_to_int_with_range_check` with the `i16` bounds. An
+/// out-of-range value therefore yields an error in ANSI mode and `None` in other
+/// modes. Handling of unparseable input is delegated to `do_cast_string_to_int`.
+fn cast_string_to_i16(str: &str, eval_mode: EvalMode) -> CometResult<Option<i16>> {
+    Ok(cast_string_to_int_with_range_check(
+        str,
+        eval_mode,
+        "SMALLINT",
+        i16::MIN as i32,
+        i16::MAX as i32,
+    )?
+    // The helper only returns `Some(v)` for `v` within the bounds passed above,
+    // so this narrowing cast cannot truncate.
+    .map(|v| v as i16))
+}
+
+/// Casts a string to an `i32`, passing `INT` as the type name to the error
+/// helpers.
+///
+/// The string is parsed directly as `i32`, so no separate range check is needed.
+/// `i32::MIN` is passed as `min_value`, from which the parser derives its
+/// overflow threshold (`stop_value`).
+fn cast_string_to_i32(str: &str, eval_mode: EvalMode) -> CometResult<Option<i32>> {
+    do_cast_string_to_int::<i32>(str, eval_mode, "INT", i32::MIN)
+}
+
+/// Casts a string to an `i64`, passing `BIGINT` as the type name to the error
+/// helpers.
+///
+/// The string is parsed directly as `i64`, so no separate range check is needed.
+/// `i64::MIN` is passed as `min_value`, from which the parser derives its
+/// overflow threshold (`stop_value`).
+fn cast_string_to_i64(str: &str, eval_mode: EvalMode) -> CometResult<Option<i64>> {
+    do_cast_string_to_int::<i64>(str, eval_mode, "BIGINT", i64::MIN)
+}
+
+/// Parses `str` as an `i32` and checks that the result lies in `min..=max`.
+///
+/// The `i8` and `i16` casts use this helper. They parse with the full `i32`
+/// range and then range-check against the narrower type's bounds.
+///
+/// Returns:
+/// - `Ok(None)` if parsing yields `None`;
+/// - `Ok(Some(v))` if `v` is within `min..=max`;
+/// - `Err(invalid_value(str, "STRING", type_name))` if the value is out of range
+///   and `eval_mode` is `EvalMode::Ansi`;
+/// - `Ok(None)` if the value is out of range in any other mode.
+///
+/// Errors from `do_cast_string_to_int` are propagated unchanged via `?`.
+fn cast_string_to_int_with_range_check(
+    str: &str,
+    eval_mode: EvalMode,
+    type_name: &str,
+    min: i32,
+    max: i32,
+) -> CometResult<Option<i32>> {
+    match do_cast_string_to_int(str, eval_mode, type_name, i32::MIN)? {
+        None => Ok(None),
+        Some(v) if v >= min && v <= max => Ok(Some(v)),
+        // Parsed successfully, but the value does not fit the target type.
+        _ if eval_mode == EvalMode::Ansi => Err(invalid_value(str, "STRING", type_name)),
+        _ => Ok(None),
+    }
+}
+
+/// States of the character-by-character string-to-integer parser in
+/// `do_cast_string_to_int`.
+///
+/// The parser compares the current state with `==`, which is why the enum
+/// derives `PartialEq`.
+#[derive(PartialEq)]
+enum State {
+    /// Skipping whitespace before the first sign or digit.
+    SkipLeadingWhiteSpace,
+    /// NOTE(review): not entered in the visible part of the parser; presumably
+    /// skips whitespace after the number — confirm against the full function.
+    SkipTrailingWhiteSpace,
+    /// Reading an optional leading `+`/`-`, then decimal digits.
+    ParseSignAndDigits,
+    /// Entered after a `.` in legacy mode, where the fractional part is truncated.
+    ParseFractionalDigits,
+}
+
+fn do_cast_string_to_int<
+    T: Num + PartialOrd + Integer + CheckedSub + CheckedNeg + From<i32> + Copy,
+>(
+    str: &str,
+    eval_mode: EvalMode,
+    type_name: &str,
+    min_value: T,
+) -> CometResult<Option<T>> {
+    let len = str.len();
+    if len == 0 {
+        return none_or_err(eval_mode, type_name, str);
+    }
+
+    let mut result: T = T::zero();
+    let mut negative = false;
+    let radix = T::from(10);
+    let stop_value = min_value / radix;
+    let mut state = State::SkipLeadingWhiteSpace;
+    let mut parsed_sign = false;
+
+    for (i, ch) in str.char_indices() {
+        // skip leading whitespace
+        if state == State::SkipLeadingWhiteSpace {
+            if ch.is_whitespace() {
+                // consume this char
+                continue;
+            }
+            // change state and fall through to next section
+            state = State::ParseSignAndDigits;
+        }
+
+        if state == State::ParseSignAndDigits {
+            if !parsed_sign {
+                negative = ch == '-';
+                let positive = ch == '+';
+                parsed_sign = true;
+                if negative || positive {
+                    if i + 1 == len {
+                        // input string is just "+" or "-"
+                        return none_or_err(eval_mode, type_name, str);
+                    }
+                    // consume this char
+                    continue;
+                }
+            }
+
+            if ch == '.' {
+                if eval_mode == EvalMode::Legacy {
+                    // truncate decimal in legacy mode
+                    state = State::ParseFractionalDigits;
+                    continue;
+                } else {
+                    return none_or_err(eval_mode, type_name, str);
+                }
+            }
+
+            let digit = if ch.is_ascii_digit() {
+                (ch as u32) - ('0' as u32)
+            } else {
+                return none_or_err(eval_mode, type_name, str);
+            };
+
+            // We are going to process the new digit and accumulate the 
result. However, before
+            // doing this, if the result is already smaller than the
+            // stopValue(Integer.MIN_VALUE / radix), then result * 10 will 
definitely be
+            // smaller than minValue, and we can stop
+            if result < stop_value {
+                return none_or_err(eval_mode, type_name, str);
+            }
+
+            // Since the previous result is less than or equal to 
stopValue(Integer.MIN_VALUE /
+            // radix), we can just use `result > 0` to check overflow. If 
result
+            // overflows, we should stop

Review Comment:
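   The comment was copied from the Spark code in 
`org/apache/spark/unsafe/types/UTF8String.java`, but I agree that it is 
incorrect. The preceding `result < stop_value` check has already returned in 
that case, so at this point the previous result is greater than or equal to 
stopValue, not less than or equal to it. I have updated the comment.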



-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

To unsubscribe, e-mail: github-unsubscr...@datafusion.apache.org

For queries about this service, please contact Infrastructure at:
us...@infra.apache.org


---------------------------------------------------------------------
To unsubscribe, e-mail: github-unsubscr...@datafusion.apache.org
For additional commands, e-mail: github-h...@datafusion.apache.org

Reply via email to