This is an automated email from the ASF dual-hosted git repository.
agrove pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/datafusion-comet.git
The following commit(s) were added to refs/heads/main by this push:
new cbf4730e feat: Implement Spark-compatible CAST from string to integral types (#307)
cbf4730e is described below
commit cbf4730ed7ba0558bca32dd2d6ed9175a00c1970
Author: Andy Grove <[email protected]>
AuthorDate: Wed May 1 01:38:35 2024 -0600
feat: Implement Spark-compatible CAST from string to integral types (#307)
---
core/Cargo.toml | 3 +
core/benches/cast.rs | 85 ++++++
core/src/execution/datafusion/expressions/cast.rs | 326 ++++++++++++++++++++-
core/src/execution/datafusion/mod.rs | 2 +-
.../scala/org/apache/comet/CometCastSuite.scala | 81 +++--
5 files changed, 475 insertions(+), 22 deletions(-)
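For orientation before the diff: the new conversion implements Spark's three cast evaluation modes, and the expected values here come from the unit tests added in cast.rs below. In Legacy mode (Spark's default), input strings are trimmed, a fractional part is truncated (so "0.2" casts to 0), and malformed or out-of-range inputs such as "128" for TINYINT become null. In Try mode (try_cast), all of those inputs, including decimals, become null. In Ansi mode, they raise a CometError::CastInvalidValue error instead of returning null.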
diff --git a/core/Cargo.toml b/core/Cargo.toml
index b09b0ea7..cbca7f62 100644
--- a/core/Cargo.toml
+++ b/core/Cargo.toml
@@ -118,3 +118,6 @@ harness = false
name = "row_columnar"
harness = false
+[[bench]]
+name = "cast"
+harness = false
diff --git a/core/benches/cast.rs b/core/benches/cast.rs
new file mode 100644
index 00000000..281fe82e
--- /dev/null
+++ b/core/benches/cast.rs
@@ -0,0 +1,85 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+use arrow_array::{builder::StringBuilder, RecordBatch};
+use arrow_schema::{DataType, Field, Schema};
+use comet::execution::datafusion::expressions::cast::{Cast, EvalMode};
+use criterion::{criterion_group, criterion_main, Criterion};
+use datafusion_physical_expr::{expressions::Column, PhysicalExpr};
+use std::sync::Arc;
+
+fn criterion_benchmark(c: &mut Criterion) {
+ let schema = Arc::new(Schema::new(vec![Field::new("a", DataType::Utf8, true)]));
+ let mut b = StringBuilder::new();
+ for i in 0..1000 {
+ if i % 10 == 0 {
+ b.append_null();
+ } else if i % 2 == 0 {
+ b.append_value(format!("{}", rand::random::<f64>()));
+ } else {
+ b.append_value(format!("{}", rand::random::<i64>()));
+ }
+ }
+ let array = b.finish();
+ let batch = RecordBatch::try_new(schema.clone(), vec![Arc::new(array)]).unwrap();
+ let expr = Arc::new(Column::new("a", 0));
+ let timezone = "".to_string();
+ let cast_string_to_i8 = Cast::new(
+ expr.clone(),
+ DataType::Int8,
+ EvalMode::Legacy,
+ timezone.clone(),
+ );
+ let cast_string_to_i16 = Cast::new(
+ expr.clone(),
+ DataType::Int16,
+ EvalMode::Legacy,
+ timezone.clone(),
+ );
+ let cast_string_to_i32 = Cast::new(
+ expr.clone(),
+ DataType::Int32,
+ EvalMode::Legacy,
+ timezone.clone(),
+ );
+ let cast_string_to_i64 = Cast::new(expr, DataType::Int64, EvalMode::Legacy, timezone);
+
+ let mut group = c.benchmark_group("cast");
+ group.bench_function("cast_string_to_i8", |b| {
+ b.iter(|| cast_string_to_i8.evaluate(&batch).unwrap());
+ });
+ group.bench_function("cast_string_to_i16", |b| {
+ b.iter(|| cast_string_to_i16.evaluate(&batch).unwrap());
+ });
+ group.bench_function("cast_string_to_i32", |b| {
+ b.iter(|| cast_string_to_i32.evaluate(&batch).unwrap());
+ });
+ group.bench_function("cast_string_to_i64", |b| {
+ b.iter(|| cast_string_to_i64.evaluate(&batch).unwrap());
+ });
+}
+
+fn config() -> Criterion {
+ Criterion::default()
+}
+
+criterion_group! {
+ name = benches;
+ config = config();
+ targets = criterion_benchmark
+}
+criterion_main!(benches);
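With the [[bench]] entry registered in core/Cargo.toml above, this Criterion benchmark can be run from the core crate with `cargo bench --bench cast`; the config() function deliberately keeps Criterion's defaults for sample count and measurement time.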
diff --git a/core/src/execution/datafusion/expressions/cast.rs b/core/src/execution/datafusion/expressions/cast.rs
index 10079855..f5839fd4 100644
--- a/core/src/execution/datafusion/expressions/cast.rs
+++ b/core/src/execution/datafusion/expressions/cast.rs
@@ -28,11 +28,15 @@ use arrow::{
record_batch::RecordBatch,
util::display::FormatOptions,
};
-use arrow_array::{Array, ArrayRef, BooleanArray, GenericStringArray, OffsetSizeTrait};
+use arrow_array::{
+ types::{Int16Type, Int32Type, Int64Type, Int8Type},
+ Array, ArrayRef, BooleanArray, GenericStringArray, OffsetSizeTrait, PrimitiveArray,
+};
use arrow_schema::{DataType, Schema};
use datafusion::logical_expr::ColumnarValue;
use datafusion_common::{internal_err, Result as DataFusionResult, ScalarValue};
use datafusion_physical_expr::PhysicalExpr;
+use num::{traits::CheckedNeg, CheckedSub, Integer, Num};
use crate::execution::datafusion::expressions::utils::{
array_with_timezone, down_cast_any_ref, spark_cast,
@@ -64,6 +68,24 @@ pub struct Cast {
pub timezone: String,
}
+macro_rules! cast_utf8_to_int {
+ ($array:expr, $eval_mode:expr, $array_type:ty, $cast_method:ident) => {{
+ let len = $array.len();
+ let mut cast_array = PrimitiveArray::<$array_type>::builder(len);
+ for i in 0..len {
+ if $array.is_null(i) {
+ cast_array.append_null()
+ } else if let Some(cast_value) = $cast_method($array.value(i).trim(), $eval_mode)? {
+ cast_array.append_value(cast_value);
+ } else {
+ cast_array.append_null()
+ }
+ }
+ let result: CometResult<ArrayRef> = Ok(Arc::new(cast_array.finish()) as ArrayRef);
+ result
+ }};
+}
+
impl Cast {
pub fn new(
child: Arc<dyn PhysicalExpr>,
@@ -103,10 +125,79 @@ impl Cast {
(DataType::LargeUtf8, DataType::Boolean) => {
Self::spark_cast_utf8_to_boolean::<i64>(&array, self.eval_mode)?
}
- _ => cast_with_options(&array, to_type, &CAST_OPTIONS)?,
+ (
+ DataType::Utf8,
+ DataType::Int8 | DataType::Int16 | DataType::Int32 | DataType::Int64,
+ ) => Self::cast_string_to_int::<i32>(to_type, &array, self.eval_mode)?,
+ (
+ DataType::LargeUtf8,
+ DataType::Int8 | DataType::Int16 | DataType::Int32 | DataType::Int64,
+ ) => Self::cast_string_to_int::<i64>(to_type, &array, self.eval_mode)?,
+ (
+ DataType::Dictionary(key_type, value_type),
+ DataType::Int8 | DataType::Int16 | DataType::Int32 | DataType::Int64,
+ ) if key_type.as_ref() == &DataType::Int32
+ && (value_type.as_ref() == &DataType::Utf8
+ || value_type.as_ref() == &DataType::LargeUtf8) =>
+ {
+ // TODO: we are unpacking a dictionary-encoded array and then performing
+ // the cast. We could potentially improve performance here by casting the
+ // dictionary values directly without unpacking the array first, although this
+ // would add more complexity to the code
+ match value_type.as_ref() {
+ DataType::Utf8 => {
+ let unpacked_array =
+ cast_with_options(&array, &DataType::Utf8, &CAST_OPTIONS)?;
+ Self::cast_string_to_int::<i32>(to_type, &unpacked_array, self.eval_mode)?
+ }
+ DataType::LargeUtf8 => {
+ let unpacked_array =
+ cast_with_options(&array, &DataType::LargeUtf8, &CAST_OPTIONS)?;
+ Self::cast_string_to_int::<i64>(to_type, &unpacked_array, self.eval_mode)?
+ }
+ dt => unreachable!(
+ "{}",
+ format!("invalid value type {dt} for
dictionary-encoded string array")
+ ),
+ }
+ }
+ _ => {
+ // when we have no Spark-specific casting we delegate to DataFusion
+ cast_with_options(&array, to_type, &CAST_OPTIONS)?
+ }
+ };
+ Ok(spark_cast(cast_result, from_type, to_type))
+ }
+
+ fn cast_string_to_int<OffsetSize: OffsetSizeTrait>(
+ to_type: &DataType,
+ array: &ArrayRef,
+ eval_mode: EvalMode,
+ ) -> CometResult<ArrayRef> {
+ let string_array = array
+ .as_any()
+ .downcast_ref::<GenericStringArray<OffsetSize>>()
+ .expect("cast_string_to_int expected a string array");
+
+ let cast_array: ArrayRef = match to_type {
+ DataType::Int8 => {
+ cast_utf8_to_int!(string_array, eval_mode, Int8Type, cast_string_to_i8)?
+ }
+ DataType::Int16 => {
+ cast_utf8_to_int!(string_array, eval_mode, Int16Type, cast_string_to_i16)?
+ }
+ DataType::Int32 => {
+ cast_utf8_to_int!(string_array, eval_mode, Int32Type, cast_string_to_i32)?
+ }
+ DataType::Int64 => {
+ cast_utf8_to_int!(string_array, eval_mode, Int64Type, cast_string_to_i64)?
+ }
+ dt => unreachable!(
+ "{}",
+ format!("invalid integer type {dt} in cast from string")
+ ),
};
- let result = spark_cast(cast_result, from_type, to_type);
- Ok(result)
+ Ok(cast_array)
}
fn spark_cast_utf8_to_boolean<OffsetSize>(
@@ -142,6 +233,202 @@ impl Cast {
}
}
+/// Equivalent to org.apache.spark.unsafe.types.UTF8String.toByte
+fn cast_string_to_i8(str: &str, eval_mode: EvalMode) -> CometResult<Option<i8>> {
+ Ok(cast_string_to_int_with_range_check(
+ str,
+ eval_mode,
+ "TINYINT",
+ i8::MIN as i32,
+ i8::MAX as i32,
+ )?
+ .map(|v| v as i8))
+}
+
+/// Equivalent to org.apache.spark.unsafe.types.UTF8String.toShort
+fn cast_string_to_i16(str: &str, eval_mode: EvalMode) -> CometResult<Option<i16>> {
+ Ok(cast_string_to_int_with_range_check(
+ str,
+ eval_mode,
+ "SMALLINT",
+ i16::MIN as i32,
+ i16::MAX as i32,
+ )?
+ .map(|v| v as i16))
+}
+
+/// Equivalent to org.apache.spark.unsafe.types.UTF8String.toInt(IntWrapper intWrapper)
+fn cast_string_to_i32(str: &str, eval_mode: EvalMode) -> CometResult<Option<i32>> {
+ do_cast_string_to_int::<i32>(str, eval_mode, "INT", i32::MIN)
+}
+
+/// Equivalent to org.apache.spark.unsafe.types.UTF8String.toLong(LongWrapper longWrapper)
+fn cast_string_to_i64(str: &str, eval_mode: EvalMode) -> CometResult<Option<i64>> {
+ do_cast_string_to_int::<i64>(str, eval_mode, "BIGINT", i64::MIN)
+}
+
+fn cast_string_to_int_with_range_check(
+ str: &str,
+ eval_mode: EvalMode,
+ type_name: &str,
+ min: i32,
+ max: i32,
+) -> CometResult<Option<i32>> {
+ match do_cast_string_to_int(str, eval_mode, type_name, i32::MIN)? {
+ None => Ok(None),
+ Some(v) if v >= min && v <= max => Ok(Some(v)),
+ _ if eval_mode == EvalMode::Ansi => Err(invalid_value(str, "STRING", type_name)),
+ _ => Ok(None),
+ }
+}
+
+#[derive(PartialEq)]
+enum State {
+ SkipLeadingWhiteSpace,
+ SkipTrailingWhiteSpace,
+ ParseSignAndDigits,
+ ParseFractionalDigits,
+}
+
+/// Equivalent to
+/// - org.apache.spark.unsafe.types.UTF8String.toInt(IntWrapper intWrapper, boolean allowDecimal)
+/// - org.apache.spark.unsafe.types.UTF8String.toLong(LongWrapper longWrapper, boolean allowDecimal)
+fn do_cast_string_to_int<
+ T: Num + PartialOrd + Integer + CheckedSub + CheckedNeg + From<i32> + Copy,
+>(
+ str: &str,
+ eval_mode: EvalMode,
+ type_name: &str,
+ min_value: T,
+) -> CometResult<Option<T>> {
+ let len = str.len();
+ if str.is_empty() {
+ return none_or_err(eval_mode, type_name, str);
+ }
+
+ let mut result: T = T::zero();
+ let mut negative = false;
+ let radix = T::from(10);
+ let stop_value = min_value / radix;
+ let mut state = State::SkipLeadingWhiteSpace;
+ let mut parsed_sign = false;
+
+ for (i, ch) in str.char_indices() {
+ // skip leading whitespace
+ if state == State::SkipLeadingWhiteSpace {
+ if ch.is_whitespace() {
+ // consume this char
+ continue;
+ }
+ // change state and fall through to next section
+ state = State::ParseSignAndDigits;
+ }
+
+ if state == State::ParseSignAndDigits {
+ if !parsed_sign {
+ negative = ch == '-';
+ let positive = ch == '+';
+ parsed_sign = true;
+ if negative || positive {
+ if i + 1 == len {
+ // input string is just "+" or "-"
+ return none_or_err(eval_mode, type_name, str);
+ }
+ // consume this char
+ continue;
+ }
+ }
+
+ if ch == '.' {
+ if eval_mode == EvalMode::Legacy {
+ // truncate decimal in legacy mode
+ state = State::ParseFractionalDigits;
+ continue;
+ } else {
+ return none_or_err(eval_mode, type_name, str);
+ }
+ }
+
+ let digit = if ch.is_ascii_digit() {
+ (ch as u32) - ('0' as u32)
+ } else {
+ return none_or_err(eval_mode, type_name, str);
+ };
+
+ // We are going to process the new digit and accumulate the result. However, before
+ // doing this, if the result is already smaller than the
+ // stopValue(Integer.MIN_VALUE / radix), then result * 10 will definitely be
+ // smaller than minValue, and we can stop
+ if result < stop_value {
+ return none_or_err(eval_mode, type_name, str);
+ }
+
+ // Since the previous result is greater than or equal to stopValue(Integer.MIN_VALUE /
+ // radix), we can just use `result > 0` to check overflow. If result
+ // overflows, we should stop
+ let v = result * radix;
+ let digit = (digit as i32).into();
+ match v.checked_sub(&digit) {
+ Some(x) if x <= T::zero() => result = x,
+ _ => {
+ return none_or_err(eval_mode, type_name, str);
+ }
+ }
+ }
+
+ if state == State::ParseFractionalDigits {
+ // This is the case when we've encountered a decimal separator. The fractional
+ // part will not change the number, but we will verify that the fractional part
+ // is well-formed.
+ if ch.is_whitespace() {
+ // finished parsing fractional digits, now need to skip trailing whitespace
+ state = State::SkipTrailingWhiteSpace;
+ // consume this char
+ continue;
+ }
+ if !ch.is_ascii_digit() {
+ return none_or_err(eval_mode, type_name, str);
+ }
+ }
+
+ // skip trailing whitespace
+ if state == State::SkipTrailingWhiteSpace && !ch.is_whitespace() {
+ return none_or_err(eval_mode, type_name, str);
+ }
+ }
+
+ if !negative {
+ if let Some(neg) = result.checked_neg() {
+ if neg < T::zero() {
+ return none_or_err(eval_mode, type_name, str);
+ }
+ result = neg;
+ } else {
+ return none_or_err(eval_mode, type_name, str);
+ }
+ }
+
+ Ok(Some(result))
+}
+
+/// Either return Ok(None) or Err(CometError::CastInvalidValue) depending on the evaluation mode
+#[inline]
+fn none_or_err<T>(eval_mode: EvalMode, type_name: &str, str: &str) -> CometResult<Option<T>> {
+ match eval_mode {
+ EvalMode::Ansi => Err(invalid_value(str, "STRING", type_name)),
+ _ => Ok(None),
+ }
+}
+
+#[inline]
+fn invalid_value(value: &str, from_type: &str, to_type: &str) -> CometError {
+ CometError::CastInvalidValue {
+ value: value.to_string(),
+ from_type: from_type.to_string(),
+ to_type: to_type.to_string(),
+ }
+}
+
impl Display for Cast {
fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
write!(
@@ -222,3 +509,34 @@ impl PhysicalExpr for Cast {
self.hash(&mut s);
}
}
+
+#[cfg(test)]
+mod test {
+ use super::{cast_string_to_i8, EvalMode};
+
+ #[test]
+ fn test_cast_string_as_i8() {
+ // basic
+ assert_eq!(
+ cast_string_to_i8("127", EvalMode::Legacy).unwrap(),
+ Some(127_i8)
+ );
+ assert_eq!(cast_string_to_i8("128", EvalMode::Legacy).unwrap(), None);
+ assert!(cast_string_to_i8("128", EvalMode::Ansi).is_err());
+ // decimals
+ assert_eq!(
+ cast_string_to_i8("0.2", EvalMode::Legacy).unwrap(),
+ Some(0_i8)
+ );
+ assert_eq!(
+ cast_string_to_i8(".", EvalMode::Legacy).unwrap(),
+ Some(0_i8)
+ );
+ // TRY should always return null for decimals
+ assert_eq!(cast_string_to_i8("0.2", EvalMode::Try).unwrap(), None);
+ assert_eq!(cast_string_to_i8(".", EvalMode::Try).unwrap(), None);
+ // ANSI mode should throw error on decimal
+ assert!(cast_string_to_i8("0.2", EvalMode::Ansi).is_err());
+ assert!(cast_string_to_i8(".", EvalMode::Ansi).is_err());
+ }
+}
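A note on the overflow guard in do_cast_string_to_int above: the magnitude is accumulated as a negative number because a two's-complement type has one more negative value than positive values (i64::MIN is -9223372036854775808 while i64::MAX is 9223372036854775807), so "-9223372036854775808" must parse even though its positive magnitude does not fit. A minimal standalone sketch of the same trick, specialized to i32 (the function name parse_i32 is illustrative only, not part of the patch, which ships the generic version over Num + CheckedSub + CheckedNeg):

fn parse_i32(s: &str) -> Option<i32> {
    // Lowest accumulator value that can still take another digit:
    // i32::MIN / 10 == -214748364.
    let stop_value = i32::MIN / 10;
    let mut chars = s.chars().peekable();
    let negative = match chars.peek() {
        Some('-') => { chars.next(); true }
        Some('+') => { chars.next(); false }
        _ => false,
    };
    let mut result: i32 = 0;
    let mut seen_digit = false;
    for ch in chars {
        let digit = ch.to_digit(10)? as i32; // reject any non-digit
        seen_digit = true;
        if result < stop_value {
            return None; // multiplying by 10 would already pass i32::MIN
        }
        // Accumulate downwards; checked_sub catches the final-digit overflow:
        // -214748364 * 10 - 9 underflows, while - 8 lands exactly on i32::MIN.
        result = result.checked_mul(10)?.checked_sub(digit)?;
    }
    if !seen_digit {
        return None; // "", "+" or "-" alone
    }
    if negative {
        Some(result)
    } else {
        // Fails only for "2147483648", the magnitude of i32::MIN.
        result.checked_neg()
    }
}

With this shape, parse_i32("-2147483648") returns Some(i32::MIN) while parse_i32("2147483648") returns None, which is exactly the asymmetry the negative accumulation exploits.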
diff --git a/core/src/execution/datafusion/mod.rs b/core/src/execution/datafusion/mod.rs
index c464eeed..76f0b1c7 100644
--- a/core/src/execution/datafusion/mod.rs
+++ b/core/src/execution/datafusion/mod.rs
@@ -17,7 +17,7 @@
//! Native execution through DataFusion
-mod expressions;
+pub mod expressions;
mod operators;
pub mod planner;
pub(crate) mod shuffle_writer;
diff --git a/spark/src/test/scala/org/apache/comet/CometCastSuite.scala b/spark/src/test/scala/org/apache/comet/CometCastSuite.scala
index c6a7c722..1bddedde 100644
--- a/spark/src/test/scala/org/apache/comet/CometCastSuite.scala
+++ b/spark/src/test/scala/org/apache/comet/CometCastSuite.scala
@@ -40,7 +40,12 @@ class CometCastSuite extends CometTestBase with AdaptiveSparkPlanHelper {
// but this is likely a reasonable starting point for now
private val whitespaceChars = " \t\r\n"
- private val numericPattern = "0123456789e+-." + whitespaceChars
+ /**
+ * We use these characters to construct strings that potentially represent valid numbers such as
+ * `-12.34d` or `4e7`. Invalid numeric strings will also be generated, such as `+e.-d`.
+ */
+ private val numericPattern = "0123456789deEf+-." + whitespaceChars
+
private val datePattern = "0123456789/" + whitespaceChars
private val timestampPattern = "0123456789/:T" + whitespaceChars
@@ -433,23 +438,64 @@ class CometCastSuite extends CometTestBase with AdaptiveSparkPlanHelper {
castTest(testValues, DataTypes.BooleanType)
}
- ignore("cast StringType to ByteType") {
- // https://github.com/apache/datafusion-comet/issues/15
- castTest(generateStrings(numericPattern, 8).toDF("a"), DataTypes.ByteType)
- }
-
- ignore("cast StringType to ShortType") {
- // https://github.com/apache/datafusion-comet/issues/15
- castTest(generateStrings(numericPattern, 8).toDF("a"), DataTypes.ShortType)
- }
-
- ignore("cast StringType to IntegerType") {
- // https://github.com/apache/datafusion-comet/issues/15
+ private val castStringToIntegralInputs: Seq[String] = Seq(
+ "",
+ ".",
+ "+",
+ "-",
+ "+.",
+ "-.",
+ "-0",
+ "+1",
+ "-1",
+ ".2",
+ "-.2",
+ "1e1",
+ "1.1d",
+ "1.1f",
+ Byte.MinValue.toString,
+ (Byte.MinValue.toShort - 1).toString,
+ Byte.MaxValue.toString,
+ (Byte.MaxValue.toShort + 1).toString,
+ Short.MinValue.toString,
+ (Short.MinValue.toInt - 1).toString,
+ Short.MaxValue.toString,
+ (Short.MaxValue.toInt + 1).toString,
+ Int.MinValue.toString,
+ (Int.MinValue.toLong - 1).toString,
+ Int.MaxValue.toString,
+ (Int.MaxValue.toLong + 1).toString,
+ Long.MinValue.toString,
+ Long.MaxValue.toString,
+ "-9223372036854775809", // Long.MinValue -1
+ "9223372036854775808" // Long.MaxValue + 1
+ )
+
+ test("cast StringType to ByteType") {
+ // test with hand-picked values
+ castTest(castStringToIntegralInputs.toDF("a"), DataTypes.ByteType)
+ // fuzz test
+ castTest(generateStrings(numericPattern, 4).toDF("a"), DataTypes.ByteType)
+ }
+
+ test("cast StringType to ShortType") {
+ // test with hand-picked values
+ castTest(castStringToIntegralInputs.toDF("a"), DataTypes.ShortType)
+ // fuzz test
+ castTest(generateStrings(numericPattern, 5).toDF("a"), DataTypes.ShortType)
+ }
+
+ test("cast StringType to IntegerType") {
+ // test with hand-picked values
+ castTest(castStringToIntegralInputs.toDF("a"), DataTypes.IntegerType)
+ // fuzz test
castTest(generateStrings(numericPattern, 8).toDF("a"), DataTypes.IntegerType)
}
- ignore("cast StringType to LongType") {
- // https://github.com/apache/datafusion-comet/issues/15
+ test("cast StringType to LongType") {
+ // test with hand-picked values
+ castTest(castStringToIntegralInputs.toDF("a"), DataTypes.LongType)
+ // fuzz test
castTest(generateStrings(numericPattern, 8).toDF("a"), DataTypes.LongType)
}
@@ -724,11 +770,12 @@ class CometCastSuite extends CometTestBase with AdaptiveSparkPlanHelper {
withSQLConf((SQLConf.ANSI_ENABLED.key, "false")) {
// cast() should return null for invalid inputs when ansi mode is disabled
- val df = data.withColumn("converted", col("a").cast(toType))
+ val df = spark.sql(s"select a, cast(a as ${toType.sql}) from t order by a")
checkSparkAnswer(df)
// try_cast() should always return null for invalid inputs
- val df2 = spark.sql(s"select try_cast(a as ${toType.sql}) from t")
+ val df2 =
+ spark.sql(s"select a, try_cast(a as ${toType.sql}) from t order by a")
checkSparkAnswer(df2)
}