andygrove commented on code in PR #307: URL: https://github.com/apache/datafusion-comet/pull/307#discussion_r1583889160
########## core/src/execution/datafusion/expressions/cast.rs: ########## @@ -142,6 +230,193 @@ impl Cast { } } +fn cast_string_to_i8(str: &str, eval_mode: EvalMode) -> CometResult<Option<i8>> { + Ok(cast_string_to_int_with_range_check( + str, + eval_mode, + "TINYINT", + i8::MIN as i32, + i8::MAX as i32, + )? + .map(|v| v as i8)) +} + +fn cast_string_to_i16(str: &str, eval_mode: EvalMode) -> CometResult<Option<i16>> { + Ok(cast_string_to_int_with_range_check( + str, + eval_mode, + "SMALLINT", + i16::MIN as i32, + i16::MAX as i32, + )? + .map(|v| v as i16)) +} + +fn cast_string_to_i32(str: &str, eval_mode: EvalMode) -> CometResult<Option<i32>> { + do_cast_string_to_int::<i32>(str, eval_mode, "INT", i32::MIN) +} + +fn cast_string_to_i64(str: &str, eval_mode: EvalMode) -> CometResult<Option<i64>> { + do_cast_string_to_int::<i64>(str, eval_mode, "BIGINT", i64::MIN) +} + +fn cast_string_to_int_with_range_check( + str: &str, + eval_mode: EvalMode, + type_name: &str, + min: i32, + max: i32, +) -> CometResult<Option<i32>> { + match do_cast_string_to_int(str, eval_mode, type_name, i32::MIN)? { + None => Ok(None), + Some(v) if v >= min && v <= max => Ok(Some(v)), + _ if eval_mode == EvalMode::Ansi => Err(invalid_value(str, "STRING", type_name)), + _ => Ok(None), + } +} + +#[derive(PartialEq)] +enum State { + SkipLeadingWhiteSpace, + SkipTrailingWhiteSpace, + ParseSignAndDigits, + ParseFractionalDigits, +} + +fn do_cast_string_to_int< + T: Num + PartialOrd + Integer + CheckedSub + CheckedNeg + From<i32> + Copy, +>( + str: &str, + eval_mode: EvalMode, + type_name: &str, + min_value: T, +) -> CometResult<Option<T>> { + let len = str.len(); + if len == 0 { + return none_or_err(eval_mode, type_name, str); + } + + let mut result: T = T::zero(); + let mut negative = false; + let radix = T::from(10); + let stop_value = min_value / radix; + let mut state = State::SkipLeadingWhiteSpace; + let mut parsed_sign = false; + + for (i, ch) in str.char_indices() { + // skip leading whitespace + if state == State::SkipLeadingWhiteSpace { + if ch.is_whitespace() { + // consume this char + continue; + } + // change state and fall through to next section + state = State::ParseSignAndDigits; + } + + if state == State::ParseSignAndDigits { + if !parsed_sign { + negative = ch == '-'; + let positive = ch == '+'; + parsed_sign = true; + if negative || positive { + if i + 1 == len { + // input string is just "+" or "-" + return none_or_err(eval_mode, type_name, str); + } + // consume this char + continue; + } + } + + if ch == '.' { + if eval_mode == EvalMode::Legacy { + // truncate decimal in legacy mode + state = State::ParseFractionalDigits; + continue; + } else { + return none_or_err(eval_mode, type_name, str); + } + } + + let digit = if ch.is_ascii_digit() { + (ch as u32) - ('0' as u32) + } else { + return none_or_err(eval_mode, type_name, str); + }; + + // We are going to process the new digit and accumulate the result. However, before + // doing this, if the result is already smaller than the + // stopValue(Integer.MIN_VALUE / radix), then result * 10 will definitely be + // smaller than minValue, and we can stop + if result < stop_value { + return none_or_err(eval_mode, type_name, str); + } + + // Since the previous result is less than or equal to stopValue(Integer.MIN_VALUE / + // radix), we can just use `result > 0` to check overflow. If result + // overflows, we should stop + let v = result * radix; + let digit = (digit as i32).into(); + match v.checked_sub(&digit) { + Some(x) if x <= T::zero() => result = x, + _ => { + return none_or_err(eval_mode, type_name, str); + } + } + } + + if state == State::ParseFractionalDigits { + // This is the case when we've encountered a decimal separator. The fractional + // part will not change the number, but we will verify that the fractional part + // is well-formed. + if ch.is_whitespace() { + // finished parsing fractional digits, now need to skip trailing whitespace + state = State::SkipTrailingWhiteSpace; + // consume this char + continue; + } + if !ch.is_ascii_digit() { + return none_or_err(eval_mode, type_name, str); + } + } + + // skip trailing whitespace + if state == State::SkipTrailingWhiteSpace && !ch.is_whitespace() { + return none_or_err(eval_mode, type_name, str); + } + } + + if !negative { + if let Some(neg) = result.checked_neg() { + if neg < T::zero() { + return none_or_err(eval_mode, type_name, str); + } + result = neg; + } else { + return none_or_err(eval_mode, type_name, str); + } + } + + Ok(Some(result)) +} + +/// Either return Ok(None) or Err(CometError::CastInvalidValue) depending on the evaluation mode +fn none_or_err<T>(eval_mode: EvalMode, type_name: &str, str: &str) -> CometResult<Option<T>> { + match eval_mode { + EvalMode::Ansi => Err(invalid_value(str, "STRING", type_name)), + _ => Ok(None), + } +} + +fn invalid_value(value: &str, from_type: &str, to_type: &str) -> CometError { + CometError::CastInvalidValue { + value: value.to_string(), + from_type: from_type.to_string(), + to_type: to_type.to_string(), + } +} Review Comment: Thanks. I have updated this. -- This is an automated message from the Apache Git Service. To respond to the message, please log on to GitHub and use the URL above to go to the specific comment. To unsubscribe, e-mail: github-unsubscr...@datafusion.apache.org For queries about this service, please contact Infrastructure at: us...@infra.apache.org --------------------------------------------------------------------- To unsubscribe, e-mail: github-unsubscr...@datafusion.apache.org For additional commands, e-mail: github-h...@datafusion.apache.org