This is an automated email from the ASF dual-hosted git repository.
agrove pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/datafusion-comet.git
The following commit(s) were added to refs/heads/main by this push:
new a7272b90 feat: Implement Spark-compatible CAST from String to Date
(#383)
a7272b90 is described below
commit a7272b90a7cc903c0b2c7482f1f44c2d325b407e
Author: vidyasankarv <[email protected]>
AuthorDate: Thu May 23 22:58:54 2024 +0530
feat: Implement Spark-compatible CAST from String to Date (#383)
* support casting DateType in comet
* Use NaiveDate methods for parsing dates and remove regex
* remove macro for date parsing
* compute correct days since epoch.
* Run String to Date test without fuzzy test
* port spark string to date processing logic for cast
* put in fixes for clippy and scalafix issues.
* handle non-digit characters when parsing segment bytes.
* add note on compatibility
* add negative tests for String to Date test
* simplify byte digit check
* propagate error correctly when a date cannot be parsed
* add fuzz test with unsupported date filtering
* add test for string array cast to date
* use UNIX_EPOCH constant from NaiveDateTime
* fix cargo clippy error - collapse else if
* do not run string to date test on spark-3.2
* do not run string to date test on spark-3.2
* add failing fuzz test dates to tests and remove failing fuzz test
* remove unused date pattern
* add specific match for casting dictionary to date
---
core/src/execution/datafusion/expressions/cast.rs | 379 ++++++++++++++++++++-
docs/source/user-guide/compatibility.md | 1 +
.../org/apache/comet/expressions/CometCast.scala | 2 +-
.../scala/org/apache/comet/CometCastSuite.scala | 68 +++-
4 files changed, 429 insertions(+), 21 deletions(-)
diff --git a/core/src/execution/datafusion/expressions/cast.rs
b/core/src/execution/datafusion/expressions/cast.rs
index f68732fb..fd1f9166 100644
--- a/core/src/execution/datafusion/expressions/cast.rs
+++ b/core/src/execution/datafusion/expressions/cast.rs
@@ -22,7 +22,6 @@ use std::{
sync::Arc,
};
-use crate::errors::{CometError, CometResult};
use arrow::{
compute::{cast_with_options, CastOptions},
datatypes::{
@@ -33,23 +32,27 @@ use arrow::{
util::display::FormatOptions,
};
use arrow_array::{
- types::{Int16Type, Int32Type, Int64Type, Int8Type},
+ types::{Date32Type, Int16Type, Int32Type, Int64Type, Int8Type},
Array, ArrayRef, BooleanArray, Decimal128Array, Float32Array,
Float64Array, GenericStringArray,
Int16Array, Int32Array, Int64Array, Int8Array, OffsetSizeTrait,
PrimitiveArray,
};
use arrow_schema::{DataType, Schema};
-use chrono::{TimeZone, Timelike};
+use chrono::{NaiveDate, NaiveDateTime, TimeZone, Timelike};
use datafusion::logical_expr::ColumnarValue;
use datafusion_common::{internal_err, Result as DataFusionResult, ScalarValue};
use datafusion_physical_expr::PhysicalExpr;
use num::{cast::AsPrimitive, traits::CheckedNeg, CheckedSub, Integer, Num,
ToPrimitive};
use regex::Regex;
-use crate::execution::datafusion::expressions::utils::{
- array_with_timezone, down_cast_any_ref, spark_cast,
+use crate::{
+ errors::{CometError, CometResult},
+ execution::datafusion::expressions::utils::{
+ array_with_timezone, down_cast_any_ref, spark_cast,
+ },
};
static TIMESTAMP_FORMAT: Option<&str> = Some("%Y-%m-%d %H:%M:%S%.f");
+
static CAST_OPTIONS: CastOptions = CastOptions {
safe: true,
format_options: FormatOptions::new()
@@ -511,6 +514,31 @@ impl Cast {
(DataType::Utf8, DataType::Timestamp(_, _)) => {
Self::cast_string_to_timestamp(&array, to_type,
self.eval_mode)?
}
+ (DataType::Utf8, DataType::Date32) => {
+ Self::cast_string_to_date(&array, to_type, self.eval_mode)?
+ }
+ (DataType::Dictionary(key_type, value_type), DataType::Date32)
+ if key_type.as_ref() == &DataType::Int32
+ && (value_type.as_ref() == &DataType::Utf8
+ || value_type.as_ref() == &DataType::LargeUtf8) =>
+ {
+ match value_type.as_ref() {
+ DataType::Utf8 => {
+ let unpacked_array =
+ cast_with_options(&array, &DataType::Utf8,
&CAST_OPTIONS)?;
+ Self::cast_string_to_date(&unpacked_array, to_type,
self.eval_mode)?
+ }
+ DataType::LargeUtf8 => {
+ let unpacked_array =
+ cast_with_options(&array, &DataType::LargeUtf8,
&CAST_OPTIONS)?;
+ Self::cast_string_to_date(&unpacked_array, to_type,
self.eval_mode)?
+ }
+ dt => unreachable!(
+ "{}",
+ format!("invalid value type {dt} for
dictionary-encoded string array")
+ ),
+ }
+ }
(DataType::Int64, DataType::Int32)
| (DataType::Int64, DataType::Int16)
| (DataType::Int64, DataType::Int8)
@@ -635,6 +663,38 @@ impl Cast {
Ok(cast_array)
}
+ fn cast_string_to_date(
+ array: &ArrayRef,
+ to_type: &DataType,
+ eval_mode: EvalMode,
+ ) -> CometResult<ArrayRef> {
+ let string_array = array
+ .as_any()
+ .downcast_ref::<GenericStringArray<i32>>()
+ .expect("Expected a string array");
+
+ let cast_array: ArrayRef = match to_type {
+ DataType::Date32 => {
+ let len = string_array.len();
+ let mut cast_array =
PrimitiveArray::<Date32Type>::builder(len);
+ for i in 0..len {
+ if !string_array.is_null(i) {
+ match date_parser(string_array.value(i), eval_mode) {
+ Ok(Some(cast_value)) =>
cast_array.append_value(cast_value),
+ Ok(None) => cast_array.append_null(),
+ Err(e) => return Err(e),
+ }
+ } else {
+ cast_array.append_null()
+ }
+ }
+ Arc::new(cast_array.finish()) as ArrayRef
+ }
+ _ => unreachable!("Invalid data type {:?} in cast from string",
to_type),
+ };
+ Ok(cast_array)
+ }
+
fn cast_string_to_timestamp(
array: &ArrayRef,
to_type: &DataType,
@@ -858,7 +918,7 @@ impl Cast {
i32,
"FLOAT",
"INT",
- std::i32::MAX,
+ i32::MAX,
"{:e}"
),
(DataType::Float32, DataType::Int64) => cast_float_to_int32_up!(
@@ -870,7 +930,7 @@ impl Cast {
i64,
"FLOAT",
"BIGINT",
- std::i64::MAX,
+ i64::MAX,
"{:e}"
),
(DataType::Float64, DataType::Int8) => cast_float_to_int16_down!(
@@ -904,7 +964,7 @@ impl Cast {
i32,
"DOUBLE",
"INT",
- std::i32::MAX,
+ i32::MAX,
"{:e}D"
),
(DataType::Float64, DataType::Int64) => cast_float_to_int32_up!(
@@ -916,7 +976,7 @@ impl Cast {
i64,
"DOUBLE",
"BIGINT",
- std::i64::MAX,
+ i64::MAX,
"{:e}D"
),
(DataType::Decimal128(precision, scale), DataType::Int8) => {
@@ -936,7 +996,7 @@ impl Cast {
Int32Array,
i32,
"INT",
- std::i32::MAX,
+ i32::MAX,
*precision,
*scale
)
@@ -948,7 +1008,7 @@ impl Cast {
Int64Array,
i64,
"BIGINT",
- std::i64::MAX,
+ i64::MAX,
*precision,
*scale
)
@@ -1264,15 +1324,15 @@ fn timestamp_parser(value: &str, eval_mode: EvalMode)
-> CometResult<Option<i64>
}
if timestamp.is_none() {
- if eval_mode == EvalMode::Ansi {
- return Err(CometError::CastInvalidValue {
+ return if eval_mode == EvalMode::Ansi {
+ Err(CometError::CastInvalidValue {
value: value.to_string(),
from_type: "STRING".to_string(),
to_type: "TIMESTAMP".to_string(),
- });
+ })
} else {
- return Ok(None);
- }
+ Ok(None)
+ };
}
match timestamp {
@@ -1409,13 +1469,136 @@ fn parse_str_to_time_only_timestamp(value: &str) ->
CometResult<Option<i64>> {
Ok(Some(timestamp))
}
+//a string to date parser - port of spark's SparkDateTimeUtils#stringToDate.
+fn date_parser(date_str: &str, eval_mode: EvalMode) ->
CometResult<Option<i32>> {
+ // local functions
+ fn get_trimmed_start(bytes: &[u8]) -> usize {
+ let mut start = 0;
+ while start < bytes.len() &&
is_whitespace_or_iso_control(bytes[start]) {
+ start += 1;
+ }
+ start
+ }
+
+ fn get_trimmed_end(start: usize, bytes: &[u8]) -> usize {
+ let mut end = bytes.len() - 1;
+ while end > start && is_whitespace_or_iso_control(bytes[end]) {
+ end -= 1;
+ }
+ end + 1
+ }
+
+ fn is_whitespace_or_iso_control(byte: u8) -> bool {
+ byte.is_ascii_whitespace() || byte.is_ascii_control()
+ }
+
+ fn is_valid_digits(segment: i32, digits: usize) -> bool {
+ // An integer is able to represent a date within [+-]5 million years.
+ let max_digits_year = 7;
+ //year (segment 0) can be between 4 to 7 digits,
+ //month and day (segment 1 and 2) can be between 1 to 2 digits
+ (segment == 0 && digits >= 4 && digits <= max_digits_year)
+ || (segment != 0 && digits > 0 && digits <= 2)
+ }
+
+ fn return_result(date_str: &str, eval_mode: EvalMode) ->
CometResult<Option<i32>> {
+ if eval_mode == EvalMode::Ansi {
+ Err(CometError::CastInvalidValue {
+ value: date_str.to_string(),
+ from_type: "STRING".to_string(),
+ to_type: "DATE".to_string(),
+ })
+ } else {
+ Ok(None)
+ }
+ }
+ // end local functions
+
+ if date_str.is_empty() {
+ return return_result(date_str, eval_mode);
+ }
+
+ //values of date segments year, month and day defaulting to 1
+ let mut date_segments = [1, 1, 1];
+ let mut sign = 1;
+ let mut current_segment = 0;
+ let mut current_segment_value = 0;
+ let mut current_segment_digits = 0;
+ let bytes = date_str.as_bytes();
+
+ let mut j = get_trimmed_start(bytes);
+ let str_end_trimmed = get_trimmed_end(j, bytes);
+
+ if j == str_end_trimmed {
+ return return_result(date_str, eval_mode);
+ }
+
+ //assign a sign to the date
+ if bytes[j] == b'-' || bytes[j] == b'+' {
+ sign = if bytes[j] == b'-' { -1 } else { 1 };
+ j += 1;
+ }
+
+ //loop to the end of string until we have processed 3 segments,
+ //exit loop on encountering any space ' ' or 'T' after the 3rd segment
+ while j < str_end_trimmed && (current_segment < 3 && !(bytes[j] == b' ' ||
bytes[j] == b'T')) {
+ let b = bytes[j];
+ if current_segment < 2 && b == b'-' {
+ //check for validity of year and month segments if current byte is
separator
+ if !is_valid_digits(current_segment, current_segment_digits) {
+ return return_result(date_str, eval_mode);
+ }
+ //if valid update corresponding segment with the current segment
value.
+ date_segments[current_segment as usize] = current_segment_value;
+ current_segment_value = 0;
+ current_segment_digits = 0;
+ current_segment += 1;
+ } else if !b.is_ascii_digit() {
+ return return_result(date_str, eval_mode);
+ } else {
+ //increment value of current segment by the next digit
+ let parsed_value = (b - b'0') as i32;
+ current_segment_value = current_segment_value * 10 + parsed_value;
+ current_segment_digits += 1;
+ }
+ j += 1;
+ }
+
+ //check for validity of last segment
+ if !is_valid_digits(current_segment, current_segment_digits) {
+ return return_result(date_str, eval_mode);
+ }
+
+ if current_segment < 2 && j < str_end_trimmed {
+ // For the `yyyy` and `yyyy-[m]m` formats, entire input must be
consumed.
+ return return_result(date_str, eval_mode);
+ }
+
+ date_segments[current_segment as usize] = current_segment_value;
+
+ match NaiveDate::from_ymd_opt(
+ sign * date_segments[0],
+ date_segments[1] as u32,
+ date_segments[2] as u32,
+ ) {
+ Some(date) => {
+ let duration_since_epoch = date
+ .signed_duration_since(NaiveDateTime::UNIX_EPOCH.date())
+ .num_days();
+ Ok(Some(duration_since_epoch.to_i32().unwrap()))
+ }
+ None => Ok(None),
+ }
+}
+
#[cfg(test)]
mod tests {
- use super::*;
use arrow::datatypes::TimestampMicrosecondType;
use arrow_array::StringArray;
use arrow_schema::TimeUnit;
+ use super::*;
+
#[test]
fn timestamp_parser_test() {
// write for all formats
@@ -1480,6 +1663,168 @@ mod tests {
assert_eq!(result.len(), 2);
}
+ #[test]
+ fn date_parser_test() {
+ for date in &[
+ "2020",
+ "2020-01",
+ "2020-01-01",
+ "02020-01-01",
+ "002020-01-01",
+ "0002020-01-01",
+ "2020-1-1",
+ "2020-01-01 ",
+ "2020-01-01T",
+ ] {
+ for eval_mode in &[EvalMode::Legacy, EvalMode::Ansi,
EvalMode::Try] {
+ assert_eq!(date_parser(*date, *eval_mode).unwrap(),
Some(18262));
+ }
+ }
+
+ //dates in invalid formats
+ for date in &[
+ "abc",
+ "",
+ "not_a_date",
+ "3/",
+ "3/12",
+ "3/12/2020",
+ "3/12/2002 T",
+ "202",
+ "2020-010-01",
+ "2020-10-010",
+ "2020-10-010T",
+ "--262143-12-31",
+ "--262143-12-31 ",
+ ] {
+ for eval_mode in &[EvalMode::Legacy, EvalMode::Try] {
+ assert_eq!(date_parser(*date, *eval_mode).unwrap(), None);
+ }
+ assert!(date_parser(*date, EvalMode::Ansi).is_err());
+ }
+
+ for date in &["-3638-5"] {
+ for eval_mode in &[EvalMode::Legacy, EvalMode::Try,
EvalMode::Ansi] {
+ assert_eq!(date_parser(*date, *eval_mode).unwrap(),
Some(-2048160));
+ }
+ }
+
+ //Naive Date only supports years 262142 AD to 262143 BC
+ //returns None for dates out of range supported by Naive Date.
+ for date in &[
+ "-262144-1-1",
+ "262143-01-1",
+ "262143-1-1",
+ "262143-01-1 ",
+ "262143-01-01T ",
+ "262143-1-01T 1234",
+ "-0973250",
+ ] {
+ for eval_mode in &[EvalMode::Legacy, EvalMode::Try,
EvalMode::Ansi] {
+ assert_eq!(date_parser(*date, *eval_mode).unwrap(), None);
+ }
+ }
+ }
+
+ #[test]
+ fn test_cast_string_to_date() {
+ let array: ArrayRef = Arc::new(StringArray::from(vec![
+ Some("2020"),
+ Some("2020-01"),
+ Some("2020-01-01"),
+ Some("2020-01-01T"),
+ ]));
+
+ let result =
+ Cast::cast_string_to_date(&array, &DataType::Date32,
EvalMode::Legacy).unwrap();
+
+ let date32_array = result
+ .as_any()
+ .downcast_ref::<arrow::array::Date32Array>()
+ .unwrap();
+ assert_eq!(date32_array.len(), 4);
+ date32_array
+ .iter()
+ .for_each(|v| assert_eq!(v.unwrap(), 18262));
+ }
+
+ #[test]
+ fn test_cast_string_array_with_valid_dates() {
+ let array_with_invalid_date: ArrayRef =
Arc::new(StringArray::from(vec![
+ Some("-262143-12-31"),
+ Some("\n -262143-12-31 "),
+ Some("-262143-12-31T \t\n"),
+ Some("\n\t-262143-12-31T\r"),
+ Some("-262143-12-31T 123123123"),
+ Some("\r\n-262143-12-31T \r123123123"),
+ Some("\n -262143-12-31T \n\t"),
+ ]));
+
+ for eval_mode in &[EvalMode::Legacy, EvalMode::Try, EvalMode::Ansi] {
+ let result =
+ Cast::cast_string_to_date(&array_with_invalid_date,
&DataType::Date32, *eval_mode)
+ .unwrap();
+
+ let date32_array = result
+ .as_any()
+ .downcast_ref::<arrow::array::Date32Array>()
+ .unwrap();
+ assert_eq!(result.len(), 7);
+ date32_array
+ .iter()
+ .for_each(|v| assert_eq!(v.unwrap(), -96464928));
+ }
+ }
+
+ #[test]
+ fn test_cast_string_array_with_invalid_dates() {
+ let array_with_invalid_date: ArrayRef =
Arc::new(StringArray::from(vec![
+ Some("2020"),
+ Some("2020-01"),
+ Some("2020-01-01"),
+ //4 invalid dates
+ Some("2020-010-01T"),
+ Some("202"),
+ Some(" 202 "),
+ Some("\n 2020-\r8 "),
+ Some("2020-01-01T"),
+ ]));
+
+ for eval_mode in &[EvalMode::Legacy, EvalMode::Try] {
+ let result =
+ Cast::cast_string_to_date(&array_with_invalid_date,
&DataType::Date32, *eval_mode)
+ .unwrap();
+
+ let date32_array = result
+ .as_any()
+ .downcast_ref::<arrow::array::Date32Array>()
+ .unwrap();
+ assert_eq!(
+ date32_array.iter().collect::<Vec<_>>(),
+ vec![
+ Some(18262),
+ Some(18262),
+ Some(18262),
+ None,
+ None,
+ None,
+ None,
+ Some(18262)
+ ]
+ );
+ }
+
+ let result =
+ Cast::cast_string_to_date(&array_with_invalid_date,
&DataType::Date32, EvalMode::Ansi);
+ match result {
+ Err(e) => assert!(
+ e.to_string().contains(
+ "[CAST_INVALID_INPUT] The value '2020-010-01T' of the type
\"STRING\" cannot be cast to \"DATE\" because it is malformed")
+ ),
+ _ => panic!("Expected error"),
+ }
+ }
+
#[test]
fn test_cast_string_as_i8() {
// basic
diff --git a/docs/source/user-guide/compatibility.md
b/docs/source/user-guide/compatibility.md
index a4ed9289..a16fd1b2 100644
--- a/docs/source/user-guide/compatibility.md
+++ b/docs/source/user-guide/compatibility.md
@@ -115,6 +115,7 @@ The following cast operations are generally compatible with
Spark except for the
| string | integer | |
| string | long | |
| string | binary | |
+| string | date | Only supports years between 262143 BC and 262142 AD |
| date | string | |
| timestamp | long | |
| timestamp | decimal | |
diff --git a/spark/src/main/scala/org/apache/comet/expressions/CometCast.scala
b/spark/src/main/scala/org/apache/comet/expressions/CometCast.scala
index 795bdb42..11c5a53c 100644
--- a/spark/src/main/scala/org/apache/comet/expressions/CometCast.scala
+++ b/spark/src/main/scala/org/apache/comet/expressions/CometCast.scala
@@ -119,7 +119,7 @@ object CometCast {
Unsupported
case DataTypes.DateType =>
// https://github.com/apache/datafusion-comet/issues/327
- Unsupported
+ Compatible(Some("Only supports years between 262143 BC and 262142 AD"))
case DataTypes.TimestampType if timeZoneId.exists(tz => tz != "UTC") =>
Incompatible(Some(s"Cast will use UTC instead of $timeZoneId"))
case DataTypes.TimestampType if evalMode == "ANSI" =>
diff --git a/spark/src/test/scala/org/apache/comet/CometCastSuite.scala
b/spark/src/test/scala/org/apache/comet/CometCastSuite.scala
index 8caba14c..1710090e 100644
--- a/spark/src/test/scala/org/apache/comet/CometCastSuite.scala
+++ b/spark/src/test/scala/org/apache/comet/CometCastSuite.scala
@@ -22,6 +22,7 @@ package org.apache.comet
import java.io.File
import scala.util.Random
+import scala.util.matching.Regex
import org.apache.spark.sql.{CometTestBase, DataFrame, SaveMode}
import org.apache.spark.sql.catalyst.expressions.Cast
@@ -33,6 +34,7 @@ import org.apache.spark.sql.types.{DataType, DataTypes,
DecimalType}
import org.apache.comet.expressions.{CometCast, Compatible}
class CometCastSuite extends CometTestBase with AdaptiveSparkPlanHelper {
+
import testImplicits._
/** Create a data generator using a fixed seed so that tests are
reproducible */
@@ -53,6 +55,7 @@ class CometCastSuite extends CometTestBase with
AdaptiveSparkPlanHelper {
private val numericPattern = "0123456789deEf+-." + whitespaceChars
private val datePattern = "0123456789/" + whitespaceChars
+
private val timestampPattern = "0123456789/:T" + whitespaceChars
test("all valid cast combinations covered") {
@@ -567,9 +570,68 @@ class CometCastSuite extends CometTestBase with
AdaptiveSparkPlanHelper {
castTest(gen.generateStrings(dataSize, numericPattern, 8).toDF("a"),
DataTypes.BinaryType)
}
- ignore("cast StringType to DateType") {
- // https://github.com/apache/datafusion-comet/issues/327
- castTest(gen.generateStrings(dataSize, datePattern, 8).toDF("a"),
DataTypes.DateType)
+ test("cast StringType to DateType") {
+ // error message for invalid dates in Spark 3.2 not supported by Comet see
below issue.
+ // https://github.com/apache/datafusion-comet/issues/440
+ assume(CometSparkSessionExtensions.isSpark33Plus)
+ val validDates = Seq(
+ "262142-01-01",
+ "262142-01-01 ",
+ "262142-01-01T ",
+ "262142-01-01T 123123123",
+ "-262143-12-31",
+ "-262143-12-31 ",
+ "-262143-12-31T",
+ "-262143-12-31T ",
+ "-262143-12-31T 123123123",
+ "2020",
+ "2020-1",
+ "2020-1-1",
+ "2020-01",
+ "2020-01-01",
+ "2020-1-01 ",
+ "2020-01-1",
+ "02020-01-01",
+ "2020-01-01T",
+ "2020-10-01T 1221213",
+ "002020-01-01 ",
+ "0002020-01-01 123344",
+ "-3638-5")
+ val invalidDates = Seq(
+ "0",
+ "202",
+ "3/",
+ "3/3/",
+ "3/3/2020",
+ "3#3#2020",
+ "2020-010-01",
+ "2020-10-010",
+ "2020-10-010T",
+ "--262143-12-31",
+ "--262143-12-31T 1234 ",
+ "abc-def-ghi",
+ "abc-def-ghi jkl",
+ "2020-mar-20",
+ "not_a_date",
+ "T2",
+ "\t\n3938\n8",
+ "8701\t",
+ "\n8757",
+ "7593\t\t\t",
+ "\t9374 \n ",
+ "\n 9850 \t",
+ "\r\n\t9840",
+ "\t9629\n",
+ "\r\n 9629 \r\n",
+ "\r\n 962 \r\n",
+ "\r\n 62 \r\n")
+
+ // due to limitations of NaiveDate we only support years between 262143 BC
and 262142 AD"
+ val unsupportedYearPattern: Regex = "^\\s*[0-9]{5,}".r
+ val fuzzDates = gen
+ .generateStrings(dataSize, datePattern, 8)
+ .filterNot(str => unsupportedYearPattern.findFirstMatchIn(str).isDefined)
+ castTest((validDates ++ invalidDates ++ fuzzDates).toDF("a"),
DataTypes.DateType)
}
test("cast StringType to TimestampType disabled by default") {
---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]