andygrove commented on code in PR #307:
URL: https://github.com/apache/datafusion-comet/pull/307#discussion_r1583889160


##########
core/src/execution/datafusion/expressions/cast.rs:
##########
@@ -142,6 +230,193 @@ impl Cast {
     }
 }
 
+fn cast_string_to_i8(str: &str, eval_mode: EvalMode) -> 
CometResult<Option<i8>> {
+    Ok(cast_string_to_int_with_range_check(
+        str,
+        eval_mode,
+        "TINYINT",
+        i8::MIN as i32,
+        i8::MAX as i32,
+    )?
+    .map(|v| v as i8))
+}
+
+fn cast_string_to_i16(str: &str, eval_mode: EvalMode) -> 
CometResult<Option<i16>> {
+    Ok(cast_string_to_int_with_range_check(
+        str,
+        eval_mode,
+        "SMALLINT",
+        i16::MIN as i32,
+        i16::MAX as i32,
+    )?
+    .map(|v| v as i16))
+}
+
+fn cast_string_to_i32(str: &str, eval_mode: EvalMode) -> 
CometResult<Option<i32>> {
+    do_cast_string_to_int::<i32>(str, eval_mode, "INT", i32::MIN)
+}
+
+fn cast_string_to_i64(str: &str, eval_mode: EvalMode) -> 
CometResult<Option<i64>> {
+    do_cast_string_to_int::<i64>(str, eval_mode, "BIGINT", i64::MIN)
+}
+
+fn cast_string_to_int_with_range_check(
+    str: &str,
+    eval_mode: EvalMode,
+    type_name: &str,
+    min: i32,
+    max: i32,
+) -> CometResult<Option<i32>> {
+    match do_cast_string_to_int(str, eval_mode, type_name, i32::MIN)? {
+        None => Ok(None),
+        Some(v) if v >= min && v <= max => Ok(Some(v)),
+        _ if eval_mode == EvalMode::Ansi => Err(invalid_value(str, "STRING", 
type_name)),
+        _ => Ok(None),
+    }
+}
+
/// States of the character-by-character parser in `do_cast_string_to_int`.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
enum State {
    /// Skipping whitespace before the first significant character.
    SkipLeadingWhiteSpace,
    /// Only whitespace may follow; any other character makes the input invalid.
    SkipTrailingWhiteSpace,
    /// Reading the optional sign and the integral digits.
    ParseSignAndDigits,
    /// After a `.`: validating (and discarding) the fractional digits.
    ParseFractionalDigits,
}
+
+fn do_cast_string_to_int<
+    T: Num + PartialOrd + Integer + CheckedSub + CheckedNeg + From<i32> + Copy,
+>(
+    str: &str,
+    eval_mode: EvalMode,
+    type_name: &str,
+    min_value: T,
+) -> CometResult<Option<T>> {
+    let len = str.len();
+    if len == 0 {
+        return none_or_err(eval_mode, type_name, str);
+    }
+
+    let mut result: T = T::zero();
+    let mut negative = false;
+    let radix = T::from(10);
+    let stop_value = min_value / radix;
+    let mut state = State::SkipLeadingWhiteSpace;
+    let mut parsed_sign = false;
+
+    for (i, ch) in str.char_indices() {
+        // skip leading whitespace
+        if state == State::SkipLeadingWhiteSpace {
+            if ch.is_whitespace() {
+                // consume this char
+                continue;
+            }
+            // change state and fall through to next section
+            state = State::ParseSignAndDigits;
+        }
+
+        if state == State::ParseSignAndDigits {
+            if !parsed_sign {
+                negative = ch == '-';
+                let positive = ch == '+';
+                parsed_sign = true;
+                if negative || positive {
+                    if i + 1 == len {
+                        // input string is just "+" or "-"
+                        return none_or_err(eval_mode, type_name, str);
+                    }
+                    // consume this char
+                    continue;
+                }
+            }
+
+            if ch == '.' {
+                if eval_mode == EvalMode::Legacy {
+                    // truncate decimal in legacy mode
+                    state = State::ParseFractionalDigits;
+                    continue;
+                } else {
+                    return none_or_err(eval_mode, type_name, str);
+                }
+            }
+
+            let digit = if ch.is_ascii_digit() {
+                (ch as u32) - ('0' as u32)
+            } else {
+                return none_or_err(eval_mode, type_name, str);
+            };
+
+            // We are going to process the new digit and accumulate the 
result. However, before
+            // doing this, if the result is already smaller than the
+            // stopValue(Integer.MIN_VALUE / radix), then result * 10 will 
definitely be
+            // smaller than minValue, and we can stop
+            if result < stop_value {
+                return none_or_err(eval_mode, type_name, str);
+            }
+
+            // Since the previous result is less than or equal to 
stopValue(Integer.MIN_VALUE /
+            // radix), we can just use `result > 0` to check overflow. If 
result
+            // overflows, we should stop
+            let v = result * radix;
+            let digit = (digit as i32).into();
+            match v.checked_sub(&digit) {
+                Some(x) if x <= T::zero() => result = x,
+                _ => {
+                    return none_or_err(eval_mode, type_name, str);
+                }
+            }
+        }
+
+        if state == State::ParseFractionalDigits {
+            // This is the case when we've encountered a decimal separator. 
The fractional
+            // part will not change the number, but we will verify that the 
fractional part
+            // is well-formed.
+            if ch.is_whitespace() {
+                // finished parsing fractional digits, now need to skip 
trailing whitespace
+                state = State::SkipTrailingWhiteSpace;
+                // consume this char
+                continue;
+            }
+            if !ch.is_ascii_digit() {
+                return none_or_err(eval_mode, type_name, str);
+            }
+        }
+
+        // skip trailing whitespace
+        if state == State::SkipTrailingWhiteSpace && !ch.is_whitespace() {
+            return none_or_err(eval_mode, type_name, str);
+        }
+    }
+
+    if !negative {
+        if let Some(neg) = result.checked_neg() {
+            if neg < T::zero() {
+                return none_or_err(eval_mode, type_name, str);
+            }
+            result = neg;
+        } else {
+            return none_or_err(eval_mode, type_name, str);
+        }
+    }
+
+    Ok(Some(result))
+}
+
+/// Either return Ok(None) or Err(CometError::CastInvalidValue) depending on 
the evaluation mode
+fn none_or_err<T>(eval_mode: EvalMode, type_name: &str, str: &str) -> 
CometResult<Option<T>> {
+    match eval_mode {
+        EvalMode::Ansi => Err(invalid_value(str, "STRING", type_name)),
+        _ => Ok(None),
+    }
+}
+
+fn invalid_value(value: &str, from_type: &str, to_type: &str) -> CometError {
+    CometError::CastInvalidValue {
+        value: value.to_string(),
+        from_type: from_type.to_string(),
+        to_type: to_type.to_string(),
+    }
+}

Review Comment:
   Thanks. I have updated this.



-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

To unsubscribe, e-mail: github-unsubscr...@datafusion.apache.org

For queries about this service, please contact Infrastructure at:
us...@infra.apache.org


---------------------------------------------------------------------
To unsubscribe, e-mail: github-unsubscr...@datafusion.apache.org
For additional commands, e-mail: github-h...@datafusion.apache.org

Reply via email to