This is an automated email from the ASF dual-hosted git repository.
parthc pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/datafusion-comet.git
The following commit(s) were added to refs/heads/main by this push:
new 4d2c398d0 fix: fix string to timestamp cast for UTC timestamps (#3656)
4d2c398d0 is described below
commit 4d2c398d06c25226f29a67fd5ad96ff68a419eeb
Author: Parth Chandra <[email protected]>
AuthorDate: Mon Mar 23 12:55:35 2026 -0700
fix: fix string to timestamp cast for UTC timestamps (#3656)
* fix: fix string to timestamp cast for UTC timestamps
---
native/common/src/error.rs | 28 +++-
native/spark-expr/src/conversion_funcs/string.rs | 171 ++++++++++++---------
.../org/apache/comet/SparkErrorConverter.scala | 2 +-
.../org/apache/comet/expressions/CometCast.scala | 3 -
.../sql/comet/shims/ShimSparkErrorConverter.scala | 18 ++-
.../sql/comet/shims/ShimSparkErrorConverter.scala | 18 ++-
.../sql/comet/shims/ShimSparkErrorConverter.scala | 7 +
.../scala/org/apache/comet/CometCastSuite.scala | 118 +++++++++++---
8 files changed, 264 insertions(+), 101 deletions(-)
diff --git a/native/common/src/error.rs b/native/common/src/error.rs
index 77d4df6cd..4c01b0e69 100644
--- a/native/common/src/error.rs
+++ b/native/common/src/error.rs
@@ -32,6 +32,18 @@ pub enum SparkError {
to_type: String,
},
+ /// Like CastInvalidValue but maps to SparkDateTimeException instead of
SparkNumberFormatException.
+ /// Used for string → timestamp/date cast failures.
+ #[error("[CAST_INVALID_INPUT] The value '{value}' of the type
\"{from_type}\" cannot be cast to \"{to_type}\" \
+ because it is malformed. Correct the value as per the syntax, or
change its target type. \
+ Use `try_cast` to tolerate malformed input and return NULL instead. If
necessary \
+ set \"spark.sql.ansi.enabled\" to \"false\" to bypass this error.")]
+ InvalidInputInCastToDatetime {
+ value: String,
+ from_type: String,
+ to_type: String,
+ },
+
#[error("[NUMERIC_VALUE_OUT_OF_RANGE.WITH_SUGGESTION] {value} cannot be
represented as Decimal({precision}, {scale}). If necessary set
\"spark.sql.ansi.enabled\" to \"false\" to bypass this error, and return NULL
instead.")]
NumericValueOutOfRange {
value: String,
@@ -208,6 +220,7 @@ impl SparkError {
pub(crate) fn error_type_name(&self) -> &'static str {
match self {
SparkError::CastInvalidValue { .. } => "CastInvalidValue",
+ SparkError::InvalidInputInCastToDatetime { .. } =>
"InvalidInputInCastToDatetime",
SparkError::NumericValueOutOfRange { .. } =>
"NumericValueOutOfRange",
SparkError::NumericOutOfRange { .. } => "NumericOutOfRange",
SparkError::CastOverFlow { .. } => "CastOverFlow",
@@ -266,6 +279,17 @@ impl SparkError {
"toType": to_type,
})
}
+ SparkError::InvalidInputInCastToDatetime {
+ value,
+ from_type,
+ to_type,
+ } => {
+ serde_json::json!({
+ "value": value,
+ "fromType": from_type,
+ "toType": to_type,
+ })
+ }
SparkError::NumericValueOutOfRange {
value,
precision,
@@ -505,7 +529,8 @@ impl SparkError {
| SparkError::ScalarSubqueryTooManyRows =>
"org/apache/spark/SparkRuntimeException",
// DateTimeException
- SparkError::CannotParseTimestamp { .. }
+ SparkError::InvalidInputInCastToDatetime { .. }
+ | SparkError::CannotParseTimestamp { .. }
| SparkError::InvalidFractionOfSecond { .. } =>
"org/apache/spark/SparkDateTimeException",
// IllegalArgumentException
@@ -530,6 +555,7 @@ impl SparkError {
match self {
// Cast errors
SparkError::CastInvalidValue { .. } => Some("CAST_INVALID_INPUT"),
+ SparkError::InvalidInputInCastToDatetime { .. } =>
Some("CAST_INVALID_INPUT"),
SparkError::CastOverFlow { .. } => Some("CAST_OVERFLOW"),
SparkError::NumericValueOutOfRange { .. } => {
Some("NUMERIC_VALUE_OUT_OF_RANGE.WITH_SUGGESTION")
diff --git a/native/spark-expr/src/conversion_funcs/string.rs
b/native/spark-expr/src/conversion_funcs/string.rs
index 7c193716d..cd1c643be 100644
--- a/native/spark-expr/src/conversion_funcs/string.rs
+++ b/native/spark-expr/src/conversion_funcs/string.rs
@@ -31,25 +31,35 @@ use num::{CheckedSub, Integer};
use regex::Regex;
use std::num::Wrapping;
use std::str::FromStr;
-use std::sync::Arc;
+use std::sync::{Arc, LazyLock};
macro_rules! cast_utf8_to_timestamp {
($array:expr, $eval_mode:expr, $array_type:ty, $cast_method:ident,
$tz:expr) => {{
let len = $array.len();
let mut cast_array =
PrimitiveArray::<$array_type>::builder(len).with_timezone("UTC");
+ let mut cast_err: Option<SparkError> = None;
for i in 0..len {
if $array.is_null(i) {
cast_array.append_null()
- } else if let Ok(Some(cast_value)) =
- $cast_method($array.value(i).trim(), $eval_mode, $tz)
- {
- cast_array.append_value(cast_value);
} else {
- cast_array.append_null()
+ match $cast_method($array.value(i).trim(), $eval_mode, $tz) {
+ Ok(Some(cast_value)) =>
cast_array.append_value(cast_value),
+ Ok(None) => cast_array.append_null(),
+ Err(e) => {
+ if $eval_mode == EvalMode::Ansi {
+ cast_err = Some(e);
+ break;
+ }
+ cast_array.append_null()
+ }
+ }
}
}
- let result: ArrayRef = Arc::new(cast_array.finish()) as ArrayRef;
- result
+ if let Some(e) = cast_err {
+ Err(e)
+ } else {
+ Ok(Arc::new(cast_array.finish()) as ArrayRef)
+ }
}};
}
@@ -668,15 +678,13 @@ pub(crate) fn cast_string_to_timestamp(
let tz = &timezone::Tz::from_str(timezone_str).unwrap();
let cast_array: ArrayRef = match to_type {
- DataType::Timestamp(_, _) => {
- cast_utf8_to_timestamp!(
- string_array,
- eval_mode,
- TimestampMicrosecondType,
- timestamp_parser,
- tz
- )
- }
+ DataType::Timestamp(_, _) => cast_utf8_to_timestamp!(
+ string_array,
+ eval_mode,
+ TimestampMicrosecondType,
+ timestamp_parser,
+ tz
+ )?,
_ => unreachable!("Invalid data type {:?} in cast from string",
to_type),
};
Ok(cast_array)
@@ -961,6 +969,12 @@ fn get_timestamp_values<T: TimeZone>(
) -> SparkResult<Option<i64>> {
let values: Vec<_> = value.split(['T', '-', ':', '.']).collect();
let year = values[0].parse::<i32>().unwrap_or_default();
+
+ // NaiveDate (used internally by chrono's with_ymd_and_hms) only supports
years in the range -262143..=262142.
+ if !(-262143..=262142).contains(&year) {
+ return Ok(None);
+ }
+
let month = values.get(1).map_or(1, |m| m.parse::<u32>().unwrap_or(1));
let day = values.get(2).map_or(1, |d| d.parse::<u32>().unwrap_or(1));
let hour = values.get(3).map_or(0, |h| h.parse::<u32>().unwrap_or(0));
@@ -1004,7 +1018,7 @@ fn get_timestamp_values<T: TimeZone>(
.with_second(second)
.with_microsecond(microsecond),
_ => {
- return Err(SparkError::CastInvalidValue {
+ return Err(SparkError::InvalidInputInCastToDatetime {
value: value.to_string(),
from_type: "STRING".to_string(),
to_type: "TIMESTAMP".to_string(),
@@ -1082,7 +1096,21 @@ fn parse_str_to_microsecond_timestamp<T: TimeZone>(
get_timestamp_values(value, "microsecond", tz)
}
-// used in tests only
+type TimestampPattern<T> = (&'static Regex, fn(&str, &T) ->
SparkResult<Option<i64>>);
+
+static RE_YEAR: LazyLock<Regex> = LazyLock::new(||
Regex::new(r"^\d{4,7}$").unwrap());
+static RE_MONTH: LazyLock<Regex> = LazyLock::new(||
Regex::new(r"^\d{4,7}-\d{2}$").unwrap());
+static RE_DAY: LazyLock<Regex> = LazyLock::new(||
Regex::new(r"^\d{4,7}-\d{2}-\d{2}$").unwrap());
+static RE_HOUR: LazyLock<Regex> =
+ LazyLock::new(|| Regex::new(r"^\d{4,7}-\d{2}-\d{2}T\d{1,2}$").unwrap());
+static RE_MINUTE: LazyLock<Regex> =
+ LazyLock::new(||
Regex::new(r"^\d{4,7}-\d{2}-\d{2}T\d{2}:\d{2}$").unwrap());
+static RE_SECOND: LazyLock<Regex> =
+ LazyLock::new(||
Regex::new(r"^\d{4,7}-\d{2}-\d{2}T\d{2}:\d{2}:\d{2}$").unwrap());
+static RE_MICROSECOND: LazyLock<Regex> =
+ LazyLock::new(||
Regex::new(r"^\d{4,7}-\d{2}-\d{2}T\d{2}:\d{2}:\d{2}\.\d{1,6}$").unwrap());
+static RE_TIME_ONLY: LazyLock<Regex> = LazyLock::new(||
Regex::new(r"^T\d{1,2}$").unwrap());
+
fn timestamp_parser<T: TimeZone>(
value: &str,
eval_mode: EvalMode,
@@ -1092,40 +1120,15 @@ fn timestamp_parser<T: TimeZone>(
if value.is_empty() {
return Ok(None);
}
- // Define regex patterns and corresponding parsing functions
- let patterns = &[
- (
- Regex::new(r"^\d{4,5}$").unwrap(),
- parse_str_to_year_timestamp as fn(&str, &T) ->
SparkResult<Option<i64>>,
- ),
- (
- Regex::new(r"^\d{4,5}-\d{2}$").unwrap(),
- parse_str_to_month_timestamp,
- ),
- (
- Regex::new(r"^\d{4,5}-\d{2}-\d{2}$").unwrap(),
- parse_str_to_day_timestamp,
- ),
- (
- Regex::new(r"^\d{4,5}-\d{2}-\d{2}T\d{1,2}$").unwrap(),
- parse_str_to_hour_timestamp,
- ),
- (
- Regex::new(r"^\d{4,5}-\d{2}-\d{2}T\d{2}:\d{2}$").unwrap(),
- parse_str_to_minute_timestamp,
- ),
- (
- Regex::new(r"^\d{4,5}-\d{2}-\d{2}T\d{2}:\d{2}:\d{2}$").unwrap(),
- parse_str_to_second_timestamp,
- ),
- (
-
Regex::new(r"^\d{4,5}-\d{2}-\d{2}T\d{2}:\d{2}:\d{2}\.\d{1,6}$").unwrap(),
- parse_str_to_microsecond_timestamp,
- ),
- (
- Regex::new(r"^T\d{1,2}$").unwrap(),
- parse_str_to_time_only_timestamp,
- ),
+ let patterns: &[TimestampPattern<T>] = &[
+ (&RE_YEAR, parse_str_to_year_timestamp),
+ (&RE_MONTH, parse_str_to_month_timestamp),
+ (&RE_DAY, parse_str_to_day_timestamp),
+ (&RE_HOUR, parse_str_to_hour_timestamp),
+ (&RE_MINUTE, parse_str_to_minute_timestamp),
+ (&RE_SECOND, parse_str_to_second_timestamp),
+ (&RE_MICROSECOND, parse_str_to_microsecond_timestamp),
+ (&RE_TIME_ONLY, parse_str_to_time_only_timestamp),
];
let mut timestamp = None;
@@ -1140,7 +1143,7 @@ fn timestamp_parser<T: TimeZone>(
if timestamp.is_none() {
return if eval_mode == EvalMode::Ansi {
- Err(SparkError::CastInvalidValue {
+ Err(SparkError::InvalidInputInCastToDatetime {
value: value.to_string(),
from_type: "STRING".to_string(),
to_type: "TIMESTAMP".to_string(),
@@ -1150,12 +1153,7 @@ fn timestamp_parser<T: TimeZone>(
};
}
- match timestamp {
- Some(ts) => Ok(Some(ts)),
- None => Err(SparkError::Internal(
- "Failed to parse timestamp".to_string(),
- )),
- }
+ Ok(timestamp)
}
fn parse_str_to_time_only_timestamp<T: TimeZone>(value: &str, tz: &T) ->
SparkResult<Option<i64>> {
@@ -1202,17 +1200,20 @@ fn date_parser(date_str: &str, eval_mode: EvalMode) ->
SparkResult<Option<i32>>
}
fn is_valid_digits(segment: i32, digits: usize) -> bool {
- // An integer is able to represent a date within [+-]5 million years.
+ // NaiveDate is bounded to [-262143, 262142] (6 digits). We allow up
to 7 digits to support
+ // leading-zero year strings like "0002020" (= year 2020), matching
Spark's
+ // isValidDigits. Values outside the bounds are caught by an explicit
bounds
+ // check below.
let max_digits_year = 7;
- //year (segment 0) can be between 4 to 7 digits,
- //month and day (segment 1 and 2) can be between 1 to 2 digits
+ // year (segment 0) can be between 4 to 7 digits,
+ // month and day (segment 1 and 2) can be between 1 to 2 digits
(segment == 0 && digits >= 4 && digits <= max_digits_year)
|| (segment != 0 && digits > 0 && digits <= 2)
}
fn return_result(date_str: &str, eval_mode: EvalMode) ->
SparkResult<Option<i32>> {
if eval_mode == EvalMode::Ansi {
- Err(SparkError::CastInvalidValue {
+ Err(SparkError::InvalidInputInCastToDatetime {
value: date_str.to_string(),
from_type: "STRING".to_string(),
to_type: "DATE".to_string(),
@@ -1285,11 +1286,13 @@ fn date_parser(date_str: &str, eval_mode: EvalMode) ->
SparkResult<Option<i32>>
date_segments[current_segment as usize] = current_segment_value.0;
- match NaiveDate::from_ymd_opt(
- sign * date_segments[0],
- date_segments[1] as u32,
- date_segments[2] as u32,
- ) {
+ // Reject out-of-range years explicitly
+ let year = sign * date_segments[0];
+ if !(-262143..=262142).contains(&year) {
+ return Ok(None);
+ }
+
+ match NaiveDate::from_ymd_opt(year, date_segments[1] as u32,
date_segments[2] as u32) {
Some(date) => {
let duration_since_epoch = date
.signed_duration_since(DateTime::UNIX_EPOCH.naive_utc().date())
@@ -1341,7 +1344,8 @@ mod tests {
TimestampMicrosecondType,
timestamp_parser,
tz
- );
+ )
+ .unwrap();
assert_eq!(
result.data_type(),
@@ -1350,6 +1354,33 @@ mod tests {
assert_eq!(result.len(), 4);
}
+ #[test]
+ fn test_cast_string_to_timestamp_ansi_error() {
+ // In ANSI mode, an invalid timestamp string must produce an error
rather than null.
+ let array: ArrayRef = Arc::new(StringArray::from(vec![
+ Some("2020-01-01T12:34:56.123456"),
+ Some("not_a_timestamp"),
+ ]));
+ let tz = &timezone::Tz::from_str("UTC").unwrap();
+ let string_array = array
+ .as_any()
+ .downcast_ref::<GenericStringArray<i32>>()
+ .expect("Expected a string array");
+
+ let eval_mode = EvalMode::Ansi;
+ let result = cast_utf8_to_timestamp!(
+ &string_array,
+ eval_mode,
+ TimestampMicrosecondType,
+ timestamp_parser,
+ tz
+ );
+ assert!(
+ result.is_err(),
+ "ANSI mode should return Err for an invalid timestamp string"
+ );
+ }
+
#[test]
fn test_cast_dict_string_to_timestamp() -> DataFusionResult<()> {
// prepare input data
diff --git a/spark/src/main/scala/org/apache/comet/SparkErrorConverter.scala
b/spark/src/main/scala/org/apache/comet/SparkErrorConverter.scala
index a8dea4cf4..284671896 100644
--- a/spark/src/main/scala/org/apache/comet/SparkErrorConverter.scala
+++ b/spark/src/main/scala/org/apache/comet/SparkErrorConverter.scala
@@ -100,7 +100,7 @@ object SparkErrorConverter extends ShimSparkErrorConverter {
case None => Array.empty[QueryContext] // No context
}
- val summary: String = errorJson.summary.orNull
+ val summary: String = errorJson.summary.getOrElse("")
// Delegate to version-specific shim - let conversion exceptions propagate
val optEx = convertErrorType(errorJson.errorType, errorClass, params,
sparkContext, summary)
diff --git a/spark/src/main/scala/org/apache/comet/expressions/CometCast.scala
b/spark/src/main/scala/org/apache/comet/expressions/CometCast.scala
index 95d536690..d50aa5d8d 100644
--- a/spark/src/main/scala/org/apache/comet/expressions/CometCast.scala
+++ b/spark/src/main/scala/org/apache/comet/expressions/CometCast.scala
@@ -217,10 +217,7 @@ object CometCast extends CometExpressionSerde[Cast] with
CometExprShim {
Compatible(Some("Only supports years between 262143 BC and 262142 AD"))
case DataTypes.TimestampType if timeZoneId.exists(tz => tz != "UTC") =>
Incompatible(Some(s"Cast will use UTC instead of $timeZoneId"))
- case DataTypes.TimestampType if evalMode == CometEvalMode.ANSI =>
- Incompatible(Some("ANSI mode not supported"))
case DataTypes.TimestampType =>
- // https://github.com/apache/datafusion-comet/issues/328
Incompatible(Some("Not all valid formats are supported"))
case _ =>
unsupported(DataTypes.StringType, toType)
diff --git
a/spark/src/main/spark-3.4/org/apache/spark/sql/comet/shims/ShimSparkErrorConverter.scala
b/spark/src/main/spark-3.4/org/apache/spark/sql/comet/shims/ShimSparkErrorConverter.scala
index e4ec9e006..6eee3f5bc 100644
---
a/spark/src/main/spark-3.4/org/apache/spark/sql/comet/shims/ShimSparkErrorConverter.scala
+++
b/spark/src/main/spark-3.4/org/apache/spark/sql/comet/shims/ShimSparkErrorConverter.scala
@@ -23,7 +23,7 @@ import java.io.FileNotFoundException
import scala.util.matching.Regex
-import org.apache.spark.{QueryContext, SparkException}
+import org.apache.spark.{QueryContext, SparkDateTimeException, SparkException}
import org.apache.spark.sql.catalyst.trees.SQLQueryContext
import org.apache.spark.sql.errors.QueryExecutionErrors
import org.apache.spark.sql.types._
@@ -172,6 +172,22 @@ trait ShimSparkErrorConverter {
QueryExecutionErrors
.invalidInputInCastToNumberError(targetType, str, sqlCtx(context)))
+ case "InvalidInputInCastToDatetime" =>
+ val expression =
+ s"'${params("value").toString.replace("\\", "\\\\").replace("'",
"\\'")}'"
+ val sourceType = s""""${params("fromType").toString}""""
+ val targetType = s""""${params("toType").toString}""""
+ Some(
+ new SparkDateTimeException(
+ errorClass = "CAST_INVALID_INPUT",
+ messageParameters = Map(
+ "expression" -> expression,
+ "sourceType" -> sourceType,
+ "targetType" -> targetType,
+ "ansiConfig" -> "\"spark.sql.ansi.enabled\""),
+ context = context,
+ summary = summary))
+
case "CastOverFlow" =>
val fromType = getDataType(params("fromType").toString)
val toType = getDataType(params("toType").toString)
diff --git
a/spark/src/main/spark-3.5/org/apache/spark/sql/comet/shims/ShimSparkErrorConverter.scala
b/spark/src/main/spark-3.5/org/apache/spark/sql/comet/shims/ShimSparkErrorConverter.scala
index 41f461100..75316c51e 100644
---
a/spark/src/main/spark-3.5/org/apache/spark/sql/comet/shims/ShimSparkErrorConverter.scala
+++
b/spark/src/main/spark-3.5/org/apache/spark/sql/comet/shims/ShimSparkErrorConverter.scala
@@ -23,7 +23,7 @@ import java.io.FileNotFoundException
import scala.util.matching.Regex
-import org.apache.spark.{QueryContext, SparkException}
+import org.apache.spark.{QueryContext, SparkDateTimeException, SparkException}
import org.apache.spark.sql.catalyst.trees.SQLQueryContext
import org.apache.spark.sql.errors.QueryExecutionErrors
import org.apache.spark.sql.types._
@@ -170,6 +170,22 @@ trait ShimSparkErrorConverter {
QueryExecutionErrors
.invalidInputInCastToNumberError(targetType, str, sqlCtx(context)))
+ case "InvalidInputInCastToDatetime" =>
+ val expression =
+ s"'${params("value").toString.replace("\\", "\\\\").replace("'",
"\\'")}'"
+ val sourceType = s""""${params("fromType").toString}""""
+ val targetType = s""""${params("toType").toString}""""
+ Some(
+ new SparkDateTimeException(
+ errorClass = "CAST_INVALID_INPUT",
+ messageParameters = Map(
+ "expression" -> expression,
+ "sourceType" -> sourceType,
+ "targetType" -> targetType,
+ "ansiConfig" -> "\"spark.sql.ansi.enabled\""),
+ context = context,
+ summary = summary))
+
case "CastOverFlow" =>
val fromType = getDataType(params("fromType").toString)
val toType = getDataType(params("toType").toString)
diff --git
a/spark/src/main/spark-4.0/org/apache/spark/sql/comet/shims/ShimSparkErrorConverter.scala
b/spark/src/main/spark-4.0/org/apache/spark/sql/comet/shims/ShimSparkErrorConverter.scala
index f906db140..fc13a58a4 100644
---
a/spark/src/main/spark-4.0/org/apache/spark/sql/comet/shims/ShimSparkErrorConverter.scala
+++
b/spark/src/main/spark-4.0/org/apache/spark/sql/comet/shims/ShimSparkErrorConverter.scala
@@ -182,6 +182,13 @@ trait ShimSparkErrorConverter {
QueryExecutionErrors
.invalidInputInCastToNumberError(targetType, str,
context.headOption.orNull))
+ case "InvalidInputInCastToDatetime" =>
+ val str = UTF8String.fromString(params("value").toString)
+ val targetType = getDataType(params("toType").toString)
+ Some(
+ QueryExecutionErrors
+ .invalidInputInCastToDatetimeError(str, targetType,
context.headOption.orNull))
+
case "CastOverFlow" =>
val fromType = getDataType(params("fromType").toString)
val toType = getDataType(params("toType").toString)
diff --git a/spark/src/test/scala/org/apache/comet/CometCastSuite.scala
b/spark/src/test/scala/org/apache/comet/CometCastSuite.scala
index 3d9acc39e..f213e90e0 100644
--- a/spark/src/test/scala/org/apache/comet/CometCastSuite.scala
+++ b/spark/src/test/scala/org/apache/comet/CometCastSuite.scala
@@ -23,7 +23,6 @@ import java.io.File
import scala.collection.mutable.ListBuffer
import scala.util.Random
-import scala.util.matching.Regex
import org.apache.hadoop.fs.Path
import org.apache.spark.sql.{CometTestBase, DataFrame, Row, SaveMode}
@@ -916,11 +915,15 @@ class CometCastSuite extends CometTestBase with
AdaptiveSparkPlanHelper {
"\r\n 962 \r\n",
"\r\n 62 \r\n")
- // due to limitations of NaiveDate we only support years between 262143 BC
and 262142 AD"
- val unsupportedYearPattern: Regex = "^\\s*[0-9]{5,}".r
+ // due to limitations of NaiveDate we only support years between 262143 BC
and 262142 AD
+ // Filter out strings where the leading digit sequence represents a year >
262142.
+ // All 5-digit years (10000-99999) are within bounds; only 6-digit years
may exceed the limit.
val fuzzDates = gen
.generateStrings(dataSize, datePattern, 8)
- .filterNot(str => unsupportedYearPattern.findFirstMatchIn(str).isDefined)
+ .filterNot { str =>
+ val yearStr = str.trim.takeWhile(_.isDigit)
+ yearStr.length > 6 || (yearStr.length == 6 && yearStr.toInt > 262142)
+ }
castTest((validDates ++ invalidDates ++ fuzzDates).toDF("a"),
DataTypes.DateType)
}
@@ -951,19 +954,47 @@ class CometCastSuite extends CometTestBase with
AdaptiveSparkPlanHelper {
}
}
- test("cast StringType to TimestampType disabled by default") {
- withSQLConf((SQLConf.SESSION_LOCAL_TIMEZONE.key, "UTC")) {
- val values = Seq("2020-01-01T12:34:56.123456", "T2").toDF("a")
- castFallbackTest(
- values.toDF("a"),
- DataTypes.TimestampType,
- "Not all valid formats are supported")
+ test("cast StringType to TimestampType - UTC") {
+ withSQLConf(SQLConf.SESSION_LOCAL_TIMEZONE.key -> "UTC") {
+ val values = Seq(
+ "2020",
+ "2020-01",
+ "2020-01-01",
+ "2020-01-01T12",
+ "2020-01-01T12:34",
+ "2020-01-01T12:34:56",
+ "2020-01-01T12:34:56.123456",
+ "T2",
+ "-9?",
+ "0100",
+ "0100-01",
+ "0100-01-01",
+ "0100-01-01T12",
+ "0100-01-01T12:34",
+ "0100-01-01T12:34:56",
+ "0100-01-01T12:34:56.123456",
+ "10000",
+ "10000-01",
+ "10000-01-01",
+ "10000-01-01T12",
+ "10000-01-01T12:34",
+ "10000-01-01T12:34:56",
+ "10000-01-01T12:34:56.123456",
+ "213170",
+ "213170-06",
+ "213170-06-15",
+ "213170-06-15T12",
+ "213170-06-15T12:34",
+ "213170-06-15T12:34:56",
+ "213170-06-15T12:34:56.123456")
+ castTimestampTest(values.toDF("a"), DataTypes.TimestampType)
}
}
ignore("cast StringType to TimestampType") {
- // https://github.com/apache/datafusion-comet/issues/328
- withSQLConf((CometConf.getExprAllowIncompatConfigKey(classOf[Cast]),
"true")) {
+ // TODO: enable once all Spark timestamp formats are supported natively.
+ // Currently missing: time-only formats with colon (e.g. "T12:34", "4:4").
+ withSQLConf((SQLConf.SESSION_LOCAL_TIMEZONE.key, "UTC")) {
val values = Seq("2020-01-01T12:34:56.123456", "T2") ++
gen.generateStrings(
dataSize,
timestampPattern,
@@ -994,6 +1025,13 @@ class CometCastSuite extends CometTestBase with
AdaptiveSparkPlanHelper {
"2020-01-01T12:34:56.123456",
"T2",
"-9?",
+ "100",
+ "100-01",
+ "100-01-01",
+ "100-01-01T12",
+ "100-01-01T12:34",
+ "100-01-01T12:34:56",
+ "100-01-01T12:34:56.123456",
"0100",
"0100-01",
"0100-01-01",
@@ -1010,14 +1048,6 @@ class CometCastSuite extends CometTestBase with
AdaptiveSparkPlanHelper {
"10000-01-01T12:34:56.123456")
castTimestampTest(values.toDF("a"), DataTypes.TimestampType)
}
-
- // test for invalid inputs
- withSQLConf(
- SQLConf.SESSION_LOCAL_TIMEZONE.key -> "UTC",
- CometConf.getExprAllowIncompatConfigKey(classOf[Cast]) -> "true") {
- val values = Seq("-9?", "1-", "0.5")
- castTimestampTest(values.toDF("a"), DataTypes.TimestampType)
- }
}
// CAST from BinaryType
@@ -1476,7 +1506,10 @@ class CometCastSuite extends CometTestBase with
AdaptiveSparkPlanHelper {
}
}
- private def castTimestampTest(input: DataFrame, toType: DataType) = {
+ private def castTimestampTest(
+ input: DataFrame,
+ toType: DataType,
+ assertNative: Boolean = false) = {
withTempPath { dir =>
val data = roundtripParquet(input, dir).coalesce(1)
data.createOrReplaceTempView("t")
@@ -1484,12 +1517,47 @@ class CometCastSuite extends CometTestBase with
AdaptiveSparkPlanHelper {
withSQLConf((SQLConf.ANSI_ENABLED.key, "false")) {
// cast() should return null for invalid inputs when ansi mode is
disabled
val df = data.withColumn("converted", col("a").cast(toType))
- checkSparkAnswer(df)
+ if (assertNative) {
+ checkSparkAnswerAndOperator(df)
+ } else {
+ checkSparkAnswer(df)
+ }
// try_cast() should always return null for invalid inputs
val df2 = spark.sql(s"select try_cast(a as ${toType.sql}) from t")
checkSparkAnswer(df2)
}
+
+ // with ANSI enabled, we should produce the same exception as Spark
+ withSQLConf(SQLConf.ANSI_ENABLED.key -> "true") {
+ val df = data.withColumn("converted", col("a").cast(toType))
+ checkSparkAnswerMaybeThrows(df) match {
+ case (None, None) =>
+ // both succeeded, results already compared
+ case (None, Some(e)) =>
+ throw e
+ case (Some(e), None) =>
+ fail(s"Comet should have failed with ${e.getCause.getMessage}")
+ case (Some(sparkException), Some(cometException)) =>
+ val sparkMessage =
+ if (sparkException.getCause != null)
sparkException.getCause.getMessage
+ else sparkException.getMessage
+ val cometMessage =
+ if (cometException.getCause != null)
cometException.getCause.getMessage
+ else cometException.getMessage
+ if (CometSparkSessionExtensions.isSpark40Plus) {
+ assert(sparkMessage.contains("SQLSTATE"))
+ assert(
+ sparkMessage.startsWith(
+ cometMessage.substring(0, math.min(40,
cometMessage.length))))
+ } else {
+ assert(cometMessage == sparkMessage)
+ }
+ }
+ // try_cast()
+ val dfTryCast = spark.sql(s"select try_cast(a as ${toType.sql}) from
t")
+ checkSparkAnswer(dfTryCast)
+ }
}
}
@@ -1575,7 +1643,9 @@ class CometCastSuite extends CometTestBase with
AdaptiveSparkPlanHelper {
// context eagerly so it displays the call site at the
// line of code where the cast method was called, whereas
spark grabs the context
// lazily and displays the call site at the line of code
where the error is checked.
- assert(sparkMessage.startsWith(cometMessage.substring(0,
40)))
+ assert(
+ sparkMessage.startsWith(
+ cometMessage.substring(0, math.min(40,
cometMessage.length))))
} else {
// for Spark 3.4 we expect to reproduce the error message
exactly
assert(cometMessage == sparkMessage)
---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]