This is an automated email from the ASF dual-hosted git repository.

agrove pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/datafusion-comet.git


The following commit(s) were added to refs/heads/main by this push:
     new cbf4730e feat: Implement Spark-compatible CAST from string to integral types (#307)
cbf4730e is described below

commit cbf4730ed7ba0558bca32dd2d6ed9175a00c1970
Author: Andy Grove <[email protected]>
AuthorDate: Wed May 1 01:38:35 2024 -0600

    feat: Implement Spark-compatible CAST from string to integral types (#307)
---
 core/Cargo.toml                                    |   3 +
 core/benches/cast.rs                               |  85 ++++++
 core/src/execution/datafusion/expressions/cast.rs  | 326 ++++++++++++++++++++-
 core/src/execution/datafusion/mod.rs               |   2 +-
 .../scala/org/apache/comet/CometCastSuite.scala    |  81 +++--
 5 files changed, 475 insertions(+), 22 deletions(-)
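
The change mirrors org.apache.spark.unsafe.types.UTF8String.toByte/toShort/toInt/toLong: in legacy mode a trailing fractional part is truncated ("0.2" casts to 0), try_cast returns null for invalid inputs, and ANSI mode returns an error (CometError::CastInvalidValue).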

diff --git a/core/Cargo.toml b/core/Cargo.toml
index b09b0ea7..cbca7f62 100644
--- a/core/Cargo.toml
+++ b/core/Cargo.toml
@@ -118,3 +118,6 @@ harness = false
 name = "row_columnar"
 harness = false
 
+[[bench]]
+name = "cast"
+harness = false
diff --git a/core/benches/cast.rs b/core/benches/cast.rs
new file mode 100644
index 00000000..281fe82e
--- /dev/null
+++ b/core/benches/cast.rs
@@ -0,0 +1,85 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements.  See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership.  The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License.  You may obtain a copy of the License at
+//
+//   http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied.  See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+use arrow_array::{builder::StringBuilder, RecordBatch};
+use arrow_schema::{DataType, Field, Schema};
+use comet::execution::datafusion::expressions::cast::{Cast, EvalMode};
+use criterion::{criterion_group, criterion_main, Criterion};
+use datafusion_physical_expr::{expressions::Column, PhysicalExpr};
+use std::sync::Arc;
+
+fn criterion_benchmark(c: &mut Criterion) {
+    let schema = Arc::new(Schema::new(vec![Field::new("a", DataType::Utf8, true)]));
+    let mut b = StringBuilder::new();
+    for i in 0..1000 {
+        if i % 10 == 0 {
+            b.append_null();
+        } else if i % 2 == 0 {
+            b.append_value(format!("{}", rand::random::<f64>()));
+        } else {
+            b.append_value(format!("{}", rand::random::<i64>()));
+        }
+    }
+    let array = b.finish();
+    let batch = RecordBatch::try_new(schema.clone(), vec![Arc::new(array)]).unwrap();
+    let expr = Arc::new(Column::new("a", 0));
+    let timezone = "".to_string();
+    let cast_string_to_i8 = Cast::new(
+        expr.clone(),
+        DataType::Int8,
+        EvalMode::Legacy,
+        timezone.clone(),
+    );
+    let cast_string_to_i16 = Cast::new(
+        expr.clone(),
+        DataType::Int16,
+        EvalMode::Legacy,
+        timezone.clone(),
+    );
+    let cast_string_to_i32 = Cast::new(
+        expr.clone(),
+        DataType::Int32,
+        EvalMode::Legacy,
+        timezone.clone(),
+    );
+    let cast_string_to_i64 = Cast::new(expr, DataType::Int64, EvalMode::Legacy, timezone);
+
+    let mut group = c.benchmark_group("cast");
+    group.bench_function("cast_string_to_i8", |b| {
+        b.iter(|| cast_string_to_i8.evaluate(&batch).unwrap());
+    });
+    group.bench_function("cast_string_to_i16", |b| {
+        b.iter(|| cast_string_to_i16.evaluate(&batch).unwrap());
+    });
+    group.bench_function("cast_string_to_i32", |b| {
+        b.iter(|| cast_string_to_i32.evaluate(&batch).unwrap());
+    });
+    group.bench_function("cast_string_to_i64", |b| {
+        b.iter(|| cast_string_to_i64.evaluate(&batch).unwrap());
+    });
+}
+
+fn config() -> Criterion {
+    Criterion::default()
+}
+
+criterion_group! {
+    name = benches;
+    config = config();
+    targets = criterion_benchmark
+}
+criterion_main!(benches);
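
With the new [[bench]] entry in core/Cargo.toml (harness = false, so Criterion supplies its own main), the benchmark should be runnable with "cargo bench --bench cast" from the core directory.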
diff --git a/core/src/execution/datafusion/expressions/cast.rs b/core/src/execution/datafusion/expressions/cast.rs
index 10079855..f5839fd4 100644
--- a/core/src/execution/datafusion/expressions/cast.rs
+++ b/core/src/execution/datafusion/expressions/cast.rs
@@ -28,11 +28,15 @@ use arrow::{
     record_batch::RecordBatch,
     util::display::FormatOptions,
 };
-use arrow_array::{Array, ArrayRef, BooleanArray, GenericStringArray, OffsetSizeTrait};
+use arrow_array::{
+    types::{Int16Type, Int32Type, Int64Type, Int8Type},
+    Array, ArrayRef, BooleanArray, GenericStringArray, OffsetSizeTrait, PrimitiveArray,
+};
 use arrow_schema::{DataType, Schema};
 use datafusion::logical_expr::ColumnarValue;
 use datafusion_common::{internal_err, Result as DataFusionResult, ScalarValue};
 use datafusion_physical_expr::PhysicalExpr;
+use num::{traits::CheckedNeg, CheckedSub, Integer, Num};
 
 use crate::execution::datafusion::expressions::utils::{
     array_with_timezone, down_cast_any_ref, spark_cast,
@@ -64,6 +68,24 @@ pub struct Cast {
     pub timezone: String,
 }
 
+macro_rules! cast_utf8_to_int {
+    ($array:expr, $eval_mode:expr, $array_type:ty, $cast_method:ident) => {{
+        let len = $array.len();
+        let mut cast_array = PrimitiveArray::<$array_type>::builder(len);
+        for i in 0..len {
+            if $array.is_null(i) {
+                cast_array.append_null()
+            } else if let Some(cast_value) = $cast_method($array.value(i).trim(), $eval_mode)? {
+                cast_array.append_value(cast_value);
+            } else {
+                cast_array.append_null()
+            }
+        }
+        let result: CometResult<ArrayRef> = Ok(Arc::new(cast_array.finish()) as ArrayRef);
+        result
+    }};
+}
+
 impl Cast {
     pub fn new(
         child: Arc<dyn PhysicalExpr>,
@@ -103,10 +125,79 @@ impl Cast {
             (DataType::LargeUtf8, DataType::Boolean) => {
                 Self::spark_cast_utf8_to_boolean::<i64>(&array, self.eval_mode)?
             }
-            _ => cast_with_options(&array, to_type, &CAST_OPTIONS)?,
+            (
+                DataType::Utf8,
+                DataType::Int8 | DataType::Int16 | DataType::Int32 | DataType::Int64,
+            ) => Self::cast_string_to_int::<i32>(to_type, &array, self.eval_mode)?,
+            (
+                DataType::LargeUtf8,
+                DataType::Int8 | DataType::Int16 | DataType::Int32 | DataType::Int64,
+            ) => Self::cast_string_to_int::<i64>(to_type, &array, self.eval_mode)?,
+            (
+                DataType::Dictionary(key_type, value_type),
+                DataType::Int8 | DataType::Int16 | DataType::Int32 | DataType::Int64,
+            ) if key_type.as_ref() == &DataType::Int32
+                && (value_type.as_ref() == &DataType::Utf8
+                    || value_type.as_ref() == &DataType::LargeUtf8) =>
+            {
+                // TODO: we are unpacking a dictionary-encoded array and then performing
+                // the cast. We could potentially improve performance here by casting the
+                // dictionary values directly without unpacking the array first, although this
+                // would add more complexity to the code
+                match value_type.as_ref() {
+                    DataType::Utf8 => {
+                        let unpacked_array =
+                            cast_with_options(&array, &DataType::Utf8, &CAST_OPTIONS)?;
+                        Self::cast_string_to_int::<i32>(to_type, &unpacked_array, self.eval_mode)?
+                    }
+                    DataType::LargeUtf8 => {
+                        let unpacked_array =
+                            cast_with_options(&array, &DataType::LargeUtf8, &CAST_OPTIONS)?;
+                        Self::cast_string_to_int::<i64>(to_type, &unpacked_array, self.eval_mode)?
+                    }
+                    dt => unreachable!(
+                        "{}",
+                        format!("invalid value type {dt} for 
dictionary-encoded string array")
+                    ),
+                }
+            }
+            _ => {
+                // when we have no Spark-specific casting we delegate to DataFusion
+                cast_with_options(&array, to_type, &CAST_OPTIONS)?
+            }
+        };
+        Ok(spark_cast(cast_result, from_type, to_type))
+    }
+
+    fn cast_string_to_int<OffsetSize: OffsetSizeTrait>(
+        to_type: &DataType,
+        array: &ArrayRef,
+        eval_mode: EvalMode,
+    ) -> CometResult<ArrayRef> {
+        let string_array = array
+            .as_any()
+            .downcast_ref::<GenericStringArray<OffsetSize>>()
+            .expect("cast_string_to_int expected a string array");
+
+        let cast_array: ArrayRef = match to_type {
+            DataType::Int8 => {
+                cast_utf8_to_int!(string_array, eval_mode, Int8Type, cast_string_to_i8)?
+            }
+            DataType::Int16 => {
+                cast_utf8_to_int!(string_array, eval_mode, Int16Type, cast_string_to_i16)?
+            }
+            DataType::Int32 => {
+                cast_utf8_to_int!(string_array, eval_mode, Int32Type, cast_string_to_i32)?
+            }
+            DataType::Int64 => {
+                cast_utf8_to_int!(string_array, eval_mode, Int64Type, cast_string_to_i64)?
+            }
+            dt => unreachable!(
+                "{}",
+                format!("invalid integer type {dt} in cast from string")
+            ),
         };
-        let result = spark_cast(cast_result, from_type, to_type);
-        Ok(result)
+        Ok(cast_array)
     }
 
     fn spark_cast_utf8_to_boolean<OffsetSize>(
@@ -142,6 +233,202 @@ impl Cast {
     }
 }
 
+/// Equivalent to org.apache.spark.unsafe.types.UTF8String.toByte
+fn cast_string_to_i8(str: &str, eval_mode: EvalMode) -> CometResult<Option<i8>> {
+    Ok(cast_string_to_int_with_range_check(
+        str,
+        eval_mode,
+        "TINYINT",
+        i8::MIN as i32,
+        i8::MAX as i32,
+    )?
+    .map(|v| v as i8))
+}
+
+/// Equivalent to org.apache.spark.unsafe.types.UTF8String.toShort
+fn cast_string_to_i16(str: &str, eval_mode: EvalMode) -> CometResult<Option<i16>> {
+    Ok(cast_string_to_int_with_range_check(
+        str,
+        eval_mode,
+        "SMALLINT",
+        i16::MIN as i32,
+        i16::MAX as i32,
+    )?
+    .map(|v| v as i16))
+}
+
+/// Equivalent to org.apache.spark.unsafe.types.UTF8String.toInt(IntWrapper intWrapper)
+fn cast_string_to_i32(str: &str, eval_mode: EvalMode) -> CometResult<Option<i32>> {
+    do_cast_string_to_int::<i32>(str, eval_mode, "INT", i32::MIN)
+}
+
+/// Equivalent to org.apache.spark.unsafe.types.UTF8String.toLong(LongWrapper longWrapper)
+fn cast_string_to_i64(str: &str, eval_mode: EvalMode) -> CometResult<Option<i64>> {
+    do_cast_string_to_int::<i64>(str, eval_mode, "BIGINT", i64::MIN)
+}
+
+fn cast_string_to_int_with_range_check(
+    str: &str,
+    eval_mode: EvalMode,
+    type_name: &str,
+    min: i32,
+    max: i32,
+) -> CometResult<Option<i32>> {
+    match do_cast_string_to_int(str, eval_mode, type_name, i32::MIN)? {
+        None => Ok(None),
+        Some(v) if v >= min && v <= max => Ok(Some(v)),
+        _ if eval_mode == EvalMode::Ansi => Err(invalid_value(str, "STRING", type_name)),
+        _ => Ok(None),
+    }
+}
+
+#[derive(PartialEq)]
+enum State {
+    SkipLeadingWhiteSpace,
+    SkipTrailingWhiteSpace,
+    ParseSignAndDigits,
+    ParseFractionalDigits,
+}
+
+/// Equivalent to
+/// - org.apache.spark.unsafe.types.UTF8String.toInt(IntWrapper intWrapper, boolean allowDecimal)
+/// - org.apache.spark.unsafe.types.UTF8String.toLong(LongWrapper longWrapper, boolean allowDecimal)
+fn do_cast_string_to_int<
+    T: Num + PartialOrd + Integer + CheckedSub + CheckedNeg + From<i32> + Copy,
+>(
+    str: &str,
+    eval_mode: EvalMode,
+    type_name: &str,
+    min_value: T,
+) -> CometResult<Option<T>> {
+    let len = str.len();
+    if str.is_empty() {
+        return none_or_err(eval_mode, type_name, str);
+    }
+
+    let mut result: T = T::zero();
+    let mut negative = false;
+    let radix = T::from(10);
+    let stop_value = min_value / radix;
+    let mut state = State::SkipLeadingWhiteSpace;
+    let mut parsed_sign = false;
+
+    for (i, ch) in str.char_indices() {
+        // skip leading whitespace
+        if state == State::SkipLeadingWhiteSpace {
+            if ch.is_whitespace() {
+                // consume this char
+                continue;
+            }
+            // change state and fall through to next section
+            state = State::ParseSignAndDigits;
+        }
+
+        if state == State::ParseSignAndDigits {
+            if !parsed_sign {
+                negative = ch == '-';
+                let positive = ch == '+';
+                parsed_sign = true;
+                if negative || positive {
+                    if i + 1 == len {
+                        // input string is just "+" or "-"
+                        return none_or_err(eval_mode, type_name, str);
+                    }
+                    // consume this char
+                    continue;
+                }
+            }
+
+            if ch == '.' {
+                if eval_mode == EvalMode::Legacy {
+                    // truncate decimal in legacy mode
+                    state = State::ParseFractionalDigits;
+                    continue;
+                } else {
+                    return none_or_err(eval_mode, type_name, str);
+                }
+            }
+
+            let digit = if ch.is_ascii_digit() {
+                (ch as u32) - ('0' as u32)
+            } else {
+                return none_or_err(eval_mode, type_name, str);
+            };
+
+            // We are going to process the new digit and accumulate the result. However, before
+            // doing this, if the result is already smaller than the
+            // stopValue(Integer.MIN_VALUE / radix), then result * 10 will definitely be
+            // smaller than minValue, and we can stop
+            if result < stop_value {
+                return none_or_err(eval_mode, type_name, str);
+            }
+
+            // Since the previous result is greater than or equal to stopValue(Integer.MIN_VALUE /
+            // radix), result * radix cannot overflow; the checked subtraction and the
+            // x <= T::zero() guard below catch any overflow from the new digit, and we stop
+            let v = result * radix;
+            let digit = (digit as i32).into();
+            match v.checked_sub(&digit) {
+                Some(x) if x <= T::zero() => result = x,
+                _ => {
+                    return none_or_err(eval_mode, type_name, str);
+                }
+            }
+        }
+
+        if state == State::ParseFractionalDigits {
+            // This is the case when we've encountered a decimal separator. The fractional
+            // part will not change the number, but we will verify that the fractional part
+            // is well-formed.
+            if ch.is_whitespace() {
+                // finished parsing fractional digits, now need to skip trailing whitespace
+                state = State::SkipTrailingWhiteSpace;
+                // consume this char
+                continue;
+            }
+            if !ch.is_ascii_digit() {
+                return none_or_err(eval_mode, type_name, str);
+            }
+        }
+
+        // skip trailing whitespace
+        if state == State::SkipTrailingWhiteSpace && !ch.is_whitespace() {
+            return none_or_err(eval_mode, type_name, str);
+        }
+    }
+
+    if !negative {
+        if let Some(neg) = result.checked_neg() {
+            if neg < T::zero() {
+                return none_or_err(eval_mode, type_name, str);
+            }
+            result = neg;
+        } else {
+            return none_or_err(eval_mode, type_name, str);
+        }
+    }
+
+    Ok(Some(result))
+}
+
+/// Either return Ok(None) or Err(CometError::CastInvalidValue) depending on the evaluation mode
+#[inline]
+fn none_or_err<T>(eval_mode: EvalMode, type_name: &str, str: &str) -> CometResult<Option<T>> {
+    match eval_mode {
+        EvalMode::Ansi => Err(invalid_value(str, "STRING", type_name)),
+        _ => Ok(None),
+    }
+}
+
+#[inline]
+fn invalid_value(value: &str, from_type: &str, to_type: &str) -> CometError {
+    CometError::CastInvalidValue {
+        value: value.to_string(),
+        from_type: from_type.to_string(),
+        to_type: to_type.to_string(),
+    }
+}
+
 impl Display for Cast {
     fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
         write!(
@@ -222,3 +509,34 @@ impl PhysicalExpr for Cast {
         self.hash(&mut s);
     }
 }
+
+#[cfg(test)]
+mod test {
+    use super::{cast_string_to_i8, EvalMode};
+
+    #[test]
+    fn test_cast_string_as_i8() {
+        // basic
+        assert_eq!(
+            cast_string_to_i8("127", EvalMode::Legacy).unwrap(),
+            Some(127_i8)
+        );
+        assert_eq!(cast_string_to_i8("128", EvalMode::Legacy).unwrap(), None);
+        assert!(cast_string_to_i8("128", EvalMode::Ansi).is_err());
+        // decimals
+        assert_eq!(
+            cast_string_to_i8("0.2", EvalMode::Legacy).unwrap(),
+            Some(0_i8)
+        );
+        assert_eq!(
+            cast_string_to_i8(".", EvalMode::Legacy).unwrap(),
+            Some(0_i8)
+        );
+        // TRY should always return null for decimals
+        assert_eq!(cast_string_to_i8("0.2", EvalMode::Try).unwrap(), None);
+        assert_eq!(cast_string_to_i8(".", EvalMode::Try).unwrap(), None);
+        // ANSI mode should throw error on decimal
+        assert!(cast_string_to_i8("0.2", EvalMode::Ansi).is_err());
+        assert!(cast_string_to_i8(".", EvalMode::Ansi).is_err());
+    }
+}
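
A note on the parsing approach above: like Spark's UTF8String.toInt/toLong, do_cast_string_to_int accumulates the result as a negative number, because T::MIN has no positive counterpart in two's complement and could not be built up from the positive side. A minimal standalone sketch of the same idea for i8 (the helper name is illustrative, not part of the patch):

    // Hypothetical sketch: accumulate negatively so that i8::MIN (-128)
    // is representable while digits are being consumed.
    fn parse_i8_negative_acc(s: &str) -> Option<i8> {
        let mut chars = s.chars().peekable();
        let negative = matches!(chars.peek(), Some(&'-'));
        if matches!(chars.peek(), Some(&'-') | Some(&'+')) {
            chars.next(); // consume the sign
        }
        let stop_value = i8::MIN / 10; // -12: below this, result * 10 overflows
        let mut result: i8 = 0;
        let mut saw_digit = false;
        for ch in chars {
            let digit = ch.to_digit(10)? as i8; // None for any non-digit
            saw_digit = true;
            if result < stop_value {
                return None; // multiplying by 10 would overflow
            }
            // subtract instead of add: the running value stays in [i8::MIN, 0]
            result = (result * 10).checked_sub(digit)?;
        }
        if !saw_digit {
            return None; // "", "+" or "-" alone
        }
        if negative {
            Some(result) // already negative
        } else {
            result.checked_neg() // fails exactly for "128", since -i8::MIN overflows
        }
    }

Parsing "-128" succeeds, while "128" fails only at the final negation, which corresponds to the checked_sub and checked_neg calls in do_cast_string_to_int.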
diff --git a/core/src/execution/datafusion/mod.rs b/core/src/execution/datafusion/mod.rs
index c464eeed..76f0b1c7 100644
--- a/core/src/execution/datafusion/mod.rs
+++ b/core/src/execution/datafusion/mod.rs
@@ -17,7 +17,7 @@
 
 //! Native execution through DataFusion
 
-mod expressions;
+pub mod expressions;
 mod operators;
 pub mod planner;
 pub(crate) mod shuffle_writer;
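
Making the expressions module public is what lets the new benchmark import comet::execution::datafusion::expressions::cast::{Cast, EvalMode} from outside the crate.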
diff --git a/spark/src/test/scala/org/apache/comet/CometCastSuite.scala b/spark/src/test/scala/org/apache/comet/CometCastSuite.scala
index c6a7c722..1bddedde 100644
--- a/spark/src/test/scala/org/apache/comet/CometCastSuite.scala
+++ b/spark/src/test/scala/org/apache/comet/CometCastSuite.scala
@@ -40,7 +40,12 @@ class CometCastSuite extends CometTestBase with AdaptiveSparkPlanHelper {
   // but this is likely a reasonable starting point for now
   private val whitespaceChars = " \t\r\n"
 
-  private val numericPattern = "0123456789e+-." + whitespaceChars
+  /**
+   * We use these characters to construct strings that potentially represent valid numbers such as
+   * `-12.34d` or `4e7`. Invalid numeric strings will also be generated, such as `+e.-d`.
+   */
+  private val numericPattern = "0123456789deEf+-." + whitespaceChars
+
   private val datePattern = "0123456789/" + whitespaceChars
   private val timestampPattern = "0123456789/:T" + whitespaceChars
 
@@ -433,23 +438,64 @@ class CometCastSuite extends CometTestBase with AdaptiveSparkPlanHelper {
     castTest(testValues, DataTypes.BooleanType)
   }
 
-  ignore("cast StringType to ByteType") {
-    // https://github.com/apache/datafusion-comet/issues/15
-    castTest(generateStrings(numericPattern, 8).toDF("a"), DataTypes.ByteType)
-  }
-
-  ignore("cast StringType to ShortType") {
-    // https://github.com/apache/datafusion-comet/issues/15
-    castTest(generateStrings(numericPattern, 8).toDF("a"), DataTypes.ShortType)
-  }
-
-  ignore("cast StringType to IntegerType") {
-    // https://github.com/apache/datafusion-comet/issues/15
+  private val castStringToIntegralInputs: Seq[String] = Seq(
+    "",
+    ".",
+    "+",
+    "-",
+    "+.",
+    "-.",
+    "-0",
+    "+1",
+    "-1",
+    ".2",
+    "-.2",
+    "1e1",
+    "1.1d",
+    "1.1f",
+    Byte.MinValue.toString,
+    (Byte.MinValue.toShort - 1).toString,
+    Byte.MaxValue.toString,
+    (Byte.MaxValue.toShort + 1).toString,
+    Short.MinValue.toString,
+    (Short.MinValue.toInt - 1).toString,
+    Short.MaxValue.toString,
+    (Short.MaxValue.toInt + 1).toString,
+    Int.MinValue.toString,
+    (Int.MinValue.toLong - 1).toString,
+    Int.MaxValue.toString,
+    (Int.MaxValue.toLong + 1).toString,
+    Long.MinValue.toString,
+    Long.MaxValue.toString,
+    "-9223372036854775809", // Long.MinValue -1
+    "9223372036854775808" // Long.MaxValue + 1
+  )
+
+  test("cast StringType to ByteType") {
+    // test with hand-picked values
+    castTest(castStringToIntegralInputs.toDF("a"), DataTypes.ByteType)
+    // fuzz test
+    castTest(generateStrings(numericPattern, 4).toDF("a"), DataTypes.ByteType)
+  }
+
+  test("cast StringType to ShortType") {
+    // test with hand-picked values
+    castTest(castStringToIntegralInputs.toDF("a"), DataTypes.ShortType)
+    // fuzz test
+    castTest(generateStrings(numericPattern, 5).toDF("a"), DataTypes.ShortType)
+  }
+
+  test("cast StringType to IntegerType") {
+    // test with hand-picked values
+    castTest(castStringToIntegralInputs.toDF("a"), DataTypes.IntegerType)
+    // fuzz test
     castTest(generateStrings(numericPattern, 8).toDF("a"), DataTypes.IntegerType)
   }
 
-  ignore("cast StringType to LongType") {
-    // https://github.com/apache/datafusion-comet/issues/15
+  test("cast StringType to LongType") {
+    // test with hand-picked values
+    castTest(castStringToIntegralInputs.toDF("a"), DataTypes.LongType)
+    // fuzz test
     castTest(generateStrings(numericPattern, 8).toDF("a"), DataTypes.LongType)
   }
 
@@ -724,11 +770,12 @@ class CometCastSuite extends CometTestBase with AdaptiveSparkPlanHelper {
 
       withSQLConf((SQLConf.ANSI_ENABLED.key, "false")) {
+        // cast() should return null for invalid inputs when ansi mode is disabled
-        val df = data.withColumn("converted", col("a").cast(toType))
+        val df = spark.sql(s"select a, cast(a as ${toType.sql}) from t order by a")
         checkSparkAnswer(df)
 
         // try_cast() should always return null for invalid inputs
-        val df2 = spark.sql(s"select try_cast(a as ${toType.sql}) from t")
+        val df2 =
+          spark.sql(s"select a, try_cast(a as ${toType.sql}) from t order by 
a")
         checkSparkAnswer(df2)
       }
 

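As an aside on the TODO in cast.rs about dictionary-encoded input: casting the dictionary values directly and reusing the keys could look roughly like the sketch below. The helper name and signature are illustrative assumptions, not part of this commit, and the caller would need to have already matched DataType::Dictionary(Int32, Utf8).

    // Hypothetical sketch of the optimization suggested by the TODO:
    // cast only the dictionary values and keep the existing Int32 keys,
    // avoiding a full unpack of the array.
    use std::sync::Arc;
    use arrow_array::{cast::AsArray, types::Int32Type, ArrayRef, DictionaryArray};
    use arrow_schema::ArrowError;

    fn cast_dictionary_values<F>(array: &ArrayRef, cast_values: F) -> Result<ArrayRef, ArrowError>
    where
        F: Fn(&ArrayRef) -> Result<ArrayRef, ArrowError>,
    {
        // as_dictionary panics if the array is not Dictionary(Int32, _),
        // so the caller must have matched the data type first
        let dict = array.as_dictionary::<Int32Type>();
        let new_values = cast_values(dict.values())?;
        // keys (and their null mask) are unchanged; only the values are re-encoded
        let result = DictionaryArray::try_new(dict.keys().clone(), new_values)?;
        Ok(Arc::new(result))
    }

The trade-off the TODO mentions still applies: the output stays dictionary-encoded, so downstream code must accept a dictionary array where it currently receives a plain integer array.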

---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]
