advancedxy commented on code in PR #307:
URL: https://github.com/apache/datafusion-comet/pull/307#discussion_r1578925501
##########
core/src/execution/datafusion/expressions/cast.rs:
##########
@@ -103,10 +125,72 @@ impl Cast {
(DataType::LargeUtf8, DataType::Boolean) => {
Self::spark_cast_utf8_to_boolean::<i64>(&array,
self.eval_mode)?
}
- _ => cast_with_options(&array, to_type, &CAST_OPTIONS)?,
+ (
+ DataType::Utf8,
+ DataType::Int8 | DataType::Int16 | DataType::Int32 |
DataType::Int64,
+ ) => Self::cast_string_to_int(to_type, &array, self.eval_mode)?,
+ (
+ DataType::Dictionary(key_type, value_type),
+ DataType::Int8 | DataType::Int16 | DataType::Int32 |
DataType::Int64,
+ ) if key_type.as_ref() == &DataType::Int32
+ && value_type.as_ref() == &DataType::Utf8 =>
+ {
+ // Note that we are unpacking a dictionary-encoded array and
then performing
+ // the cast. We could potentially improve performance here by
casting the
+ // dictionary values directly without unpacking the array
first, although this
+ // would add more complexity to the code
Review Comment:
I think we can leave a TODO here to cast the dictionary values directly?
##########
core/src/execution/datafusion/expressions/cast.rs:
##########
@@ -142,6 +226,281 @@ impl Cast {
}
}
+fn cast_string_to_i8(str: &str, eval_mode: EvalMode) ->
CometResult<Option<i8>> {
+ Ok(cast_string_to_int_with_range_check(
+ str,
+ eval_mode,
+ "TINYINT",
+ i8::MIN as i32,
+ i8::MAX as i32,
+ )?
+ .map(|v| v as i8))
+}
+
+fn cast_string_to_i16(str: &str, eval_mode: EvalMode) ->
CometResult<Option<i16>> {
+ Ok(cast_string_to_int_with_range_check(
+ str,
+ eval_mode,
+ "SMALLINT",
+ i16::MIN as i32,
+ i16::MAX as i32,
+ )?
+ .map(|v| v as i16))
+}
+
+fn cast_string_to_i32(str: &str, eval_mode: EvalMode) ->
CometResult<Option<i32>> {
+ let mut accum = CastStringToInt32::default();
+ do_cast_string_to_int(&mut accum, str, eval_mode, "INT")?;
+ Ok(accum.result)
+}
+
+fn cast_string_to_i64(str: &str, eval_mode: EvalMode) ->
CometResult<Option<i64>> {
+ let mut accum = CastStringToInt64::default();
+ do_cast_string_to_int(&mut accum, str, eval_mode, "BIGINT")?;
+ Ok(accum.result)
+}
+
+fn cast_string_to_int_with_range_check(
+ str: &str,
+ eval_mode: EvalMode,
+ type_name: &str,
+ min: i32,
+ max: i32,
+) -> CometResult<Option<i32>> {
+ let mut accum = CastStringToInt32::default();
+ do_cast_string_to_int(&mut accum, str, eval_mode, type_name)?;
+ match accum.result {
+ None => Ok(None),
+ Some(v) if v >= min && v <= max => Ok(Some(v)),
+ _ if eval_mode == EvalMode::Ansi => Err(invalid_value(str, "STRING",
type_name)),
+ _ => Ok(None),
+ }
+}
+
+/// We support parsing strings to i32 and i64 to match Spark's logic. Support
for i8 and i16 is
+/// implemented by first parsing as i32 and then downcasting. The
CastStringToInt trait is
+/// introduced so that we can have the parsing logic delegate either to an i32
or i64 accumulator
+/// and avoid the need to use macros here.
+trait CastStringToInt {
+ fn accumulate(
+ &mut self,
+ eval_mode: EvalMode,
+ type_name: &str,
+ str: &str,
+ digit: u32,
+ ) -> CometResult<()>;
+
+ fn reset(&mut self);
+
+ fn finish(
+ &mut self,
+ eval_mode: EvalMode,
+ type_name: &str,
+ str: &str,
+ negative: bool,
+ ) -> CometResult<()>;
+}
+struct CastStringToInt32 {
+ negative: bool,
+ result: Option<i32>,
+ radix: i32,
+}
+
+impl Default for CastStringToInt32 {
+ fn default() -> Self {
+ Self {
+ negative: false,
+ result: Some(0),
+ radix: 10,
+ }
+ }
+}
+
+impl CastStringToInt for CastStringToInt32 {
+ fn accumulate(
+ &mut self,
+ eval_mode: EvalMode,
+ type_name: &str,
+ str: &str,
+ digit: u32,
+ ) -> CometResult<()> {
+ if self.result.is_some() && self.result.unwrap() < i32::MIN /
self.radix {
+ self.reset();
+ return none_or_err(eval_mode, type_name, str);
+ }
+ self.result = Some(self.result.unwrap_or(0) * self.radix - digit as
i32);
+ if self.result.unwrap() > 0 {
+ self.reset();
+ return none_or_err(eval_mode, type_name, str);
+ }
+ Ok(())
+ }
+ fn reset(&mut self) {
+ self.result = None;
+ }
+
+ fn finish(
+ &mut self,
+ eval_mode: EvalMode,
+ type_name: &str,
+ str: &str,
+ negative: bool,
+ ) -> CometResult<()> {
+ if self.result.is_some() && !negative {
+ self.result = Some(-self.result.unwrap());
+ if self.result.unwrap() < 0 {
+ return none_or_err(eval_mode, type_name, str);
+ }
+ }
+ Ok(())
+ }
+}
+
+struct CastStringToInt64 {
+ negative: bool,
+ result: Option<i64>,
+ radix: i64,
+}
+
+impl Default for CastStringToInt64 {
+ fn default() -> Self {
+ Self {
+ negative: false,
+ result: Some(0),
+ radix: 10,
+ }
+ }
+}
+
+impl CastStringToInt for CastStringToInt64 {
+ fn accumulate(
+ &mut self,
+ eval_mode: EvalMode,
+ type_name: &str,
+ str: &str,
+ digit: u32,
+ ) -> CometResult<()> {
+ if self.result.unwrap_or(0) < i64::MIN / self.radix {
+ self.reset();
+ return none_or_err(eval_mode, type_name, str);
+ }
+ self.result = Some(self.result.unwrap_or(0) * self.radix - digit as
i64);
+ if self.result.unwrap() > 0 {
+ self.reset();
+ return none_or_err(eval_mode, type_name, str);
+ }
+ Ok(())
+ }
+
+ fn reset(&mut self) {
+ self.result = None;
+ }
+
+ fn finish(
+ &mut self,
+ eval_mode: EvalMode,
+ type_name: &str,
+ str: &str,
+ negative: bool,
+ ) -> CometResult<()> {
+ if self.result.is_some() && !negative {
+ self.result = Some(-self.result.unwrap());
+ if self.result.unwrap() < 0 {
+ return none_or_err(eval_mode, type_name, str);
+ }
+ }
+ Ok(())
+ }
+}
+
+fn do_cast_string_to_int(
+ accumulator: &mut dyn CastStringToInt,
+ str: &str,
+ eval_mode: EvalMode,
+ type_name: &str,
+) -> CometResult<()> {
+ let chars: Vec<char> = str.chars().collect();
+ let mut i = 0;
+ let mut end = chars.len();
+
+ // skip leading whitespace
+ while i < end && chars[i].is_whitespace() {
+ i += 1;
+ }
+
+ // skip trailing whitespace
+ while end > i && chars[end - 1].is_whitespace() {
+ end -= 1;
+ }
+
+ // check for empty string
+ if i == end {
+ accumulator.reset();
+ return Ok(());
+ }
+
+ // skip + or -
+ let negative = chars[0] == '-';
+ if negative || chars[0] == '+' {
Review Comment:
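This seems wrong.
Should it be `chars[i] == '-'` instead (and likewise `chars[i] == '+'` on the
next line)? Leading whitespace has already been skipped at this point, so the
sign is at index `i`, not `0`. Otherwise, this cast doesn't work for ` -124`.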
##########
core/src/execution/datafusion/expressions/cast.rs:
##########
@@ -103,10 +125,72 @@ impl Cast {
(DataType::LargeUtf8, DataType::Boolean) => {
Self::spark_cast_utf8_to_boolean::<i64>(&array,
self.eval_mode)?
Review Comment:
Not part of this PR, but if we are going to name the added method
`cast_string_to_int`, should this method be renamed to
`cast_utf8_to_boolean` as well in a follow-up PR?
##########
core/src/execution/datafusion/expressions/cast.rs:
##########
@@ -222,3 +581,34 @@ impl PhysicalExpr for Cast {
self.hash(&mut s);
}
}
+
+#[cfg(test)]
+mod test {
+ use super::{cast_string_to_i8, EvalMode};
+
+ #[test]
+ fn test_cast_string_as_i8() {
Review Comment:
How about adding more tests for `i32` and `i64` with their min/max and zero
inputs?
--
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.
To unsubscribe, e-mail: [email protected]
For queries about this service, please contact Infrastructure at:
[email protected]
---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]