This is an automated email from the ASF dual-hosted git repository.

parthc pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/datafusion-comet.git


The following commit(s) were added to refs/heads/main by this push:
     new 4d2c398d0 fix: fix string to timestamp cast for UTC timestamps (#3656)
4d2c398d0 is described below

commit 4d2c398d06c25226f29a67fd5ad96ff68a419eeb
Author: Parth Chandra <[email protected]>
AuthorDate: Mon Mar 23 12:55:35 2026 -0700

    fix: fix string to timestamp cast for UTC timestamps (#3656)
    
    * fix: fix string to timestamp cast for UTC timestamps
---
 native/common/src/error.rs                         |  28 +++-
 native/spark-expr/src/conversion_funcs/string.rs   | 171 ++++++++++++---------
 .../org/apache/comet/SparkErrorConverter.scala     |   2 +-
 .../org/apache/comet/expressions/CometCast.scala   |   3 -
 .../sql/comet/shims/ShimSparkErrorConverter.scala  |  18 ++-
 .../sql/comet/shims/ShimSparkErrorConverter.scala  |  18 ++-
 .../sql/comet/shims/ShimSparkErrorConverter.scala  |   7 +
 .../scala/org/apache/comet/CometCastSuite.scala    | 118 +++++++++++---
 8 files changed, 264 insertions(+), 101 deletions(-)

diff --git a/native/common/src/error.rs b/native/common/src/error.rs
index 77d4df6cd..4c01b0e69 100644
--- a/native/common/src/error.rs
+++ b/native/common/src/error.rs
@@ -32,6 +32,18 @@ pub enum SparkError {
         to_type: String,
     },
 
+    /// Like CastInvalidValue but maps to SparkDateTimeException instead of 
SparkNumberFormatException.
+    /// Used for string → timestamp/date cast failures.
+    #[error("[CAST_INVALID_INPUT] The value '{value}' of the type 
\"{from_type}\" cannot be cast to \"{to_type}\" \
+        because it is malformed. Correct the value as per the syntax, or 
change its target type. \
+        Use `try_cast` to tolerate malformed input and return NULL instead. If 
necessary \
+        set \"spark.sql.ansi.enabled\" to \"false\" to bypass this error.")]
+    InvalidInputInCastToDatetime {
+        value: String,
+        from_type: String,
+        to_type: String,
+    },
+
     #[error("[NUMERIC_VALUE_OUT_OF_RANGE.WITH_SUGGESTION] {value} cannot be 
represented as Decimal({precision}, {scale}). If necessary set 
\"spark.sql.ansi.enabled\" to \"false\" to bypass this error, and return NULL 
instead.")]
     NumericValueOutOfRange {
         value: String,
@@ -208,6 +220,7 @@ impl SparkError {
     pub(crate) fn error_type_name(&self) -> &'static str {
         match self {
             SparkError::CastInvalidValue { .. } => "CastInvalidValue",
+            SparkError::InvalidInputInCastToDatetime { .. } => 
"InvalidInputInCastToDatetime",
             SparkError::NumericValueOutOfRange { .. } => 
"NumericValueOutOfRange",
             SparkError::NumericOutOfRange { .. } => "NumericOutOfRange",
             SparkError::CastOverFlow { .. } => "CastOverFlow",
@@ -266,6 +279,17 @@ impl SparkError {
                     "toType": to_type,
                 })
             }
+            SparkError::InvalidInputInCastToDatetime {
+                value,
+                from_type,
+                to_type,
+            } => {
+                serde_json::json!({
+                    "value": value,
+                    "fromType": from_type,
+                    "toType": to_type,
+                })
+            }
             SparkError::NumericValueOutOfRange {
                 value,
                 precision,
@@ -505,7 +529,8 @@ impl SparkError {
             | SparkError::ScalarSubqueryTooManyRows => 
"org/apache/spark/SparkRuntimeException",
 
             // DateTimeException
-            SparkError::CannotParseTimestamp { .. }
+            SparkError::InvalidInputInCastToDatetime { .. }
+            | SparkError::CannotParseTimestamp { .. }
             | SparkError::InvalidFractionOfSecond { .. } => 
"org/apache/spark/SparkDateTimeException",
 
             // IllegalArgumentException
@@ -530,6 +555,7 @@ impl SparkError {
         match self {
             // Cast errors
             SparkError::CastInvalidValue { .. } => Some("CAST_INVALID_INPUT"),
+            SparkError::InvalidInputInCastToDatetime { .. } => 
Some("CAST_INVALID_INPUT"),
             SparkError::CastOverFlow { .. } => Some("CAST_OVERFLOW"),
             SparkError::NumericValueOutOfRange { .. } => {
                 Some("NUMERIC_VALUE_OUT_OF_RANGE.WITH_SUGGESTION")
diff --git a/native/spark-expr/src/conversion_funcs/string.rs 
b/native/spark-expr/src/conversion_funcs/string.rs
index 7c193716d..cd1c643be 100644
--- a/native/spark-expr/src/conversion_funcs/string.rs
+++ b/native/spark-expr/src/conversion_funcs/string.rs
@@ -31,25 +31,35 @@ use num::{CheckedSub, Integer};
 use regex::Regex;
 use std::num::Wrapping;
 use std::str::FromStr;
-use std::sync::Arc;
+use std::sync::{Arc, LazyLock};
 
 macro_rules! cast_utf8_to_timestamp {
     ($array:expr, $eval_mode:expr, $array_type:ty, $cast_method:ident, 
$tz:expr) => {{
         let len = $array.len();
         let mut cast_array = 
PrimitiveArray::<$array_type>::builder(len).with_timezone("UTC");
+        let mut cast_err: Option<SparkError> = None;
         for i in 0..len {
             if $array.is_null(i) {
                 cast_array.append_null()
-            } else if let Ok(Some(cast_value)) =
-                $cast_method($array.value(i).trim(), $eval_mode, $tz)
-            {
-                cast_array.append_value(cast_value);
             } else {
-                cast_array.append_null()
+                match $cast_method($array.value(i).trim(), $eval_mode, $tz) {
+                    Ok(Some(cast_value)) => 
cast_array.append_value(cast_value),
+                    Ok(None) => cast_array.append_null(),
+                    Err(e) => {
+                        if $eval_mode == EvalMode::Ansi {
+                            cast_err = Some(e);
+                            break;
+                        }
+                        cast_array.append_null()
+                    }
+                }
             }
         }
-        let result: ArrayRef = Arc::new(cast_array.finish()) as ArrayRef;
-        result
+        if let Some(e) = cast_err {
+            Err(e)
+        } else {
+            Ok(Arc::new(cast_array.finish()) as ArrayRef)
+        }
     }};
 }
 
@@ -668,15 +678,13 @@ pub(crate) fn cast_string_to_timestamp(
     let tz = &timezone::Tz::from_str(timezone_str).unwrap();
 
     let cast_array: ArrayRef = match to_type {
-        DataType::Timestamp(_, _) => {
-            cast_utf8_to_timestamp!(
-                string_array,
-                eval_mode,
-                TimestampMicrosecondType,
-                timestamp_parser,
-                tz
-            )
-        }
+        DataType::Timestamp(_, _) => cast_utf8_to_timestamp!(
+            string_array,
+            eval_mode,
+            TimestampMicrosecondType,
+            timestamp_parser,
+            tz
+        )?,
         _ => unreachable!("Invalid data type {:?} in cast from string", 
to_type),
     };
     Ok(cast_array)
@@ -961,6 +969,12 @@ fn get_timestamp_values<T: TimeZone>(
 ) -> SparkResult<Option<i64>> {
     let values: Vec<_> = value.split(['T', '-', ':', '.']).collect();
     let year = values[0].parse::<i32>().unwrap_or_default();
+
+    // NaiveDate (used internally by chrono's with_ymd_and_hms) is bounded to 
±262142.
+    if !(-262143..=262142).contains(&year) {
+        return Ok(None);
+    }
+
     let month = values.get(1).map_or(1, |m| m.parse::<u32>().unwrap_or(1));
     let day = values.get(2).map_or(1, |d| d.parse::<u32>().unwrap_or(1));
     let hour = values.get(3).map_or(0, |h| h.parse::<u32>().unwrap_or(0));
@@ -1004,7 +1018,7 @@ fn get_timestamp_values<T: TimeZone>(
             .with_second(second)
             .with_microsecond(microsecond),
         _ => {
-            return Err(SparkError::CastInvalidValue {
+            return Err(SparkError::InvalidInputInCastToDatetime {
                 value: value.to_string(),
                 from_type: "STRING".to_string(),
                 to_type: "TIMESTAMP".to_string(),
@@ -1082,7 +1096,21 @@ fn parse_str_to_microsecond_timestamp<T: TimeZone>(
     get_timestamp_values(value, "microsecond", tz)
 }
 
-// used in tests only
+type TimestampPattern<T> = (&'static Regex, fn(&str, &T) -> 
SparkResult<Option<i64>>);
+
+static RE_YEAR: LazyLock<Regex> = LazyLock::new(|| 
Regex::new(r"^\d{4,7}$").unwrap());
+static RE_MONTH: LazyLock<Regex> = LazyLock::new(|| 
Regex::new(r"^\d{4,7}-\d{2}$").unwrap());
+static RE_DAY: LazyLock<Regex> = LazyLock::new(|| 
Regex::new(r"^\d{4,7}-\d{2}-\d{2}$").unwrap());
+static RE_HOUR: LazyLock<Regex> =
+    LazyLock::new(|| Regex::new(r"^\d{4,7}-\d{2}-\d{2}T\d{1,2}$").unwrap());
+static RE_MINUTE: LazyLock<Regex> =
+    LazyLock::new(|| 
Regex::new(r"^\d{4,7}-\d{2}-\d{2}T\d{2}:\d{2}$").unwrap());
+static RE_SECOND: LazyLock<Regex> =
+    LazyLock::new(|| 
Regex::new(r"^\d{4,7}-\d{2}-\d{2}T\d{2}:\d{2}:\d{2}$").unwrap());
+static RE_MICROSECOND: LazyLock<Regex> =
+    LazyLock::new(|| 
Regex::new(r"^\d{4,7}-\d{2}-\d{2}T\d{2}:\d{2}:\d{2}\.\d{1,6}$").unwrap());
+static RE_TIME_ONLY: LazyLock<Regex> = LazyLock::new(|| 
Regex::new(r"^T\d{1,2}$").unwrap());
+
 fn timestamp_parser<T: TimeZone>(
     value: &str,
     eval_mode: EvalMode,
@@ -1092,40 +1120,15 @@ fn timestamp_parser<T: TimeZone>(
     if value.is_empty() {
         return Ok(None);
     }
-    // Define regex patterns and corresponding parsing functions
-    let patterns = &[
-        (
-            Regex::new(r"^\d{4,5}$").unwrap(),
-            parse_str_to_year_timestamp as fn(&str, &T) -> 
SparkResult<Option<i64>>,
-        ),
-        (
-            Regex::new(r"^\d{4,5}-\d{2}$").unwrap(),
-            parse_str_to_month_timestamp,
-        ),
-        (
-            Regex::new(r"^\d{4,5}-\d{2}-\d{2}$").unwrap(),
-            parse_str_to_day_timestamp,
-        ),
-        (
-            Regex::new(r"^\d{4,5}-\d{2}-\d{2}T\d{1,2}$").unwrap(),
-            parse_str_to_hour_timestamp,
-        ),
-        (
-            Regex::new(r"^\d{4,5}-\d{2}-\d{2}T\d{2}:\d{2}$").unwrap(),
-            parse_str_to_minute_timestamp,
-        ),
-        (
-            Regex::new(r"^\d{4,5}-\d{2}-\d{2}T\d{2}:\d{2}:\d{2}$").unwrap(),
-            parse_str_to_second_timestamp,
-        ),
-        (
-            
Regex::new(r"^\d{4,5}-\d{2}-\d{2}T\d{2}:\d{2}:\d{2}\.\d{1,6}$").unwrap(),
-            parse_str_to_microsecond_timestamp,
-        ),
-        (
-            Regex::new(r"^T\d{1,2}$").unwrap(),
-            parse_str_to_time_only_timestamp,
-        ),
+    let patterns: &[TimestampPattern<T>] = &[
+        (&RE_YEAR, parse_str_to_year_timestamp),
+        (&RE_MONTH, parse_str_to_month_timestamp),
+        (&RE_DAY, parse_str_to_day_timestamp),
+        (&RE_HOUR, parse_str_to_hour_timestamp),
+        (&RE_MINUTE, parse_str_to_minute_timestamp),
+        (&RE_SECOND, parse_str_to_second_timestamp),
+        (&RE_MICROSECOND, parse_str_to_microsecond_timestamp),
+        (&RE_TIME_ONLY, parse_str_to_time_only_timestamp),
     ];
 
     let mut timestamp = None;
@@ -1140,7 +1143,7 @@ fn timestamp_parser<T: TimeZone>(
 
     if timestamp.is_none() {
         return if eval_mode == EvalMode::Ansi {
-            Err(SparkError::CastInvalidValue {
+            Err(SparkError::InvalidInputInCastToDatetime {
                 value: value.to_string(),
                 from_type: "STRING".to_string(),
                 to_type: "TIMESTAMP".to_string(),
@@ -1150,12 +1153,7 @@ fn timestamp_parser<T: TimeZone>(
         };
     }
 
-    match timestamp {
-        Some(ts) => Ok(Some(ts)),
-        None => Err(SparkError::Internal(
-            "Failed to parse timestamp".to_string(),
-        )),
-    }
+    Ok(timestamp)
 }
 
 fn parse_str_to_time_only_timestamp<T: TimeZone>(value: &str, tz: &T) -> 
SparkResult<Option<i64>> {
@@ -1202,17 +1200,20 @@ fn date_parser(date_str: &str, eval_mode: EvalMode) -> 
SparkResult<Option<i32>>
     }
 
     fn is_valid_digits(segment: i32, digits: usize) -> bool {
-        // An integer is able to represent a date within [+-]5 million years.
+        // NaiveDate is bounded to [-262142, 262142] (6 digits). We allow up 
to 7 digits to support
+        // leading-zero year strings like "0002020" (= year 2020), matching 
Spark's
+        // isValidDigits. Values outside the bounds are caught by an explicit 
bounds
+        // check below.
         let max_digits_year = 7;
-        //year (segment 0) can be between 4 to 7 digits,
-        //month and day (segment 1 and 2) can be between 1 to 2 digits
+        // year (segment 0) can be between 4 to 7 digits,
+        // month and day (segment 1 and 2) can be between 1 to 2 digits
         (segment == 0 && digits >= 4 && digits <= max_digits_year)
             || (segment != 0 && digits > 0 && digits <= 2)
     }
 
     fn return_result(date_str: &str, eval_mode: EvalMode) -> 
SparkResult<Option<i32>> {
         if eval_mode == EvalMode::Ansi {
-            Err(SparkError::CastInvalidValue {
+            Err(SparkError::InvalidInputInCastToDatetime {
                 value: date_str.to_string(),
                 from_type: "STRING".to_string(),
                 to_type: "DATE".to_string(),
@@ -1285,11 +1286,13 @@ fn date_parser(date_str: &str, eval_mode: EvalMode) -> 
SparkResult<Option<i32>>
 
     date_segments[current_segment as usize] = current_segment_value.0;
 
-    match NaiveDate::from_ymd_opt(
-        sign * date_segments[0],
-        date_segments[1] as u32,
-        date_segments[2] as u32,
-    ) {
+    // Reject out-of-range years explicitly
+    let year = sign * date_segments[0];
+    if !(-262143..=262142).contains(&year) {
+        return Ok(None);
+    }
+
+    match NaiveDate::from_ymd_opt(year, date_segments[1] as u32, 
date_segments[2] as u32) {
         Some(date) => {
             let duration_since_epoch = date
                 .signed_duration_since(DateTime::UNIX_EPOCH.naive_utc().date())
@@ -1341,7 +1344,8 @@ mod tests {
             TimestampMicrosecondType,
             timestamp_parser,
             tz
-        );
+        )
+        .unwrap();
 
         assert_eq!(
             result.data_type(),
@@ -1350,6 +1354,33 @@ mod tests {
         assert_eq!(result.len(), 4);
     }
 
+    #[test]
+    fn test_cast_string_to_timestamp_ansi_error() {
+        // In ANSI mode, an invalid timestamp string must produce an error 
rather than null.
+        let array: ArrayRef = Arc::new(StringArray::from(vec![
+            Some("2020-01-01T12:34:56.123456"),
+            Some("not_a_timestamp"),
+        ]));
+        let tz = &timezone::Tz::from_str("UTC").unwrap();
+        let string_array = array
+            .as_any()
+            .downcast_ref::<GenericStringArray<i32>>()
+            .expect("Expected a string array");
+
+        let eval_mode = EvalMode::Ansi;
+        let result = cast_utf8_to_timestamp!(
+            &string_array,
+            eval_mode,
+            TimestampMicrosecondType,
+            timestamp_parser,
+            tz
+        );
+        assert!(
+            result.is_err(),
+            "ANSI mode should return Err for an invalid timestamp string"
+        );
+    }
+
     #[test]
     fn test_cast_dict_string_to_timestamp() -> DataFusionResult<()> {
         // prepare input data
diff --git a/spark/src/main/scala/org/apache/comet/SparkErrorConverter.scala 
b/spark/src/main/scala/org/apache/comet/SparkErrorConverter.scala
index a8dea4cf4..284671896 100644
--- a/spark/src/main/scala/org/apache/comet/SparkErrorConverter.scala
+++ b/spark/src/main/scala/org/apache/comet/SparkErrorConverter.scala
@@ -100,7 +100,7 @@ object SparkErrorConverter extends ShimSparkErrorConverter {
       case None => Array.empty[QueryContext] // No context
     }
 
-    val summary: String = errorJson.summary.orNull
+    val summary: String = errorJson.summary.getOrElse("")
 
     // Delegate to version-specific shim - let conversion exceptions propagate
     val optEx = convertErrorType(errorJson.errorType, errorClass, params, 
sparkContext, summary)
diff --git a/spark/src/main/scala/org/apache/comet/expressions/CometCast.scala 
b/spark/src/main/scala/org/apache/comet/expressions/CometCast.scala
index 95d536690..d50aa5d8d 100644
--- a/spark/src/main/scala/org/apache/comet/expressions/CometCast.scala
+++ b/spark/src/main/scala/org/apache/comet/expressions/CometCast.scala
@@ -217,10 +217,7 @@ object CometCast extends CometExpressionSerde[Cast] with 
CometExprShim {
         Compatible(Some("Only supports years between 262143 BC and 262142 AD"))
       case DataTypes.TimestampType if timeZoneId.exists(tz => tz != "UTC") =>
         Incompatible(Some(s"Cast will use UTC instead of $timeZoneId"))
-      case DataTypes.TimestampType if evalMode == CometEvalMode.ANSI =>
-        Incompatible(Some("ANSI mode not supported"))
       case DataTypes.TimestampType =>
-        // https://github.com/apache/datafusion-comet/issues/328
         Incompatible(Some("Not all valid formats are supported"))
       case _ =>
         unsupported(DataTypes.StringType, toType)
diff --git 
a/spark/src/main/spark-3.4/org/apache/spark/sql/comet/shims/ShimSparkErrorConverter.scala
 
b/spark/src/main/spark-3.4/org/apache/spark/sql/comet/shims/ShimSparkErrorConverter.scala
index e4ec9e006..6eee3f5bc 100644
--- 
a/spark/src/main/spark-3.4/org/apache/spark/sql/comet/shims/ShimSparkErrorConverter.scala
+++ 
b/spark/src/main/spark-3.4/org/apache/spark/sql/comet/shims/ShimSparkErrorConverter.scala
@@ -23,7 +23,7 @@ import java.io.FileNotFoundException
 
 import scala.util.matching.Regex
 
-import org.apache.spark.{QueryContext, SparkException}
+import org.apache.spark.{QueryContext, SparkDateTimeException, SparkException}
 import org.apache.spark.sql.catalyst.trees.SQLQueryContext
 import org.apache.spark.sql.errors.QueryExecutionErrors
 import org.apache.spark.sql.types._
@@ -172,6 +172,22 @@ trait ShimSparkErrorConverter {
           QueryExecutionErrors
             .invalidInputInCastToNumberError(targetType, str, sqlCtx(context)))
 
+      case "InvalidInputInCastToDatetime" =>
+        val expression =
+          s"'${params("value").toString.replace("\\", "\\\\").replace("'", 
"\\'")}'"
+        val sourceType = s""""${params("fromType").toString}""""
+        val targetType = s""""${params("toType").toString}""""
+        Some(
+          new SparkDateTimeException(
+            errorClass = "CAST_INVALID_INPUT",
+            messageParameters = Map(
+              "expression" -> expression,
+              "sourceType" -> sourceType,
+              "targetType" -> targetType,
+              "ansiConfig" -> "\"spark.sql.ansi.enabled\""),
+            context = context,
+            summary = summary))
+
       case "CastOverFlow" =>
         val fromType = getDataType(params("fromType").toString)
         val toType = getDataType(params("toType").toString)
diff --git 
a/spark/src/main/spark-3.5/org/apache/spark/sql/comet/shims/ShimSparkErrorConverter.scala
 
b/spark/src/main/spark-3.5/org/apache/spark/sql/comet/shims/ShimSparkErrorConverter.scala
index 41f461100..75316c51e 100644
--- 
a/spark/src/main/spark-3.5/org/apache/spark/sql/comet/shims/ShimSparkErrorConverter.scala
+++ 
b/spark/src/main/spark-3.5/org/apache/spark/sql/comet/shims/ShimSparkErrorConverter.scala
@@ -23,7 +23,7 @@ import java.io.FileNotFoundException
 
 import scala.util.matching.Regex
 
-import org.apache.spark.{QueryContext, SparkException}
+import org.apache.spark.{QueryContext, SparkDateTimeException, SparkException}
 import org.apache.spark.sql.catalyst.trees.SQLQueryContext
 import org.apache.spark.sql.errors.QueryExecutionErrors
 import org.apache.spark.sql.types._
@@ -170,6 +170,22 @@ trait ShimSparkErrorConverter {
           QueryExecutionErrors
             .invalidInputInCastToNumberError(targetType, str, sqlCtx(context)))
 
+      case "InvalidInputInCastToDatetime" =>
+        val expression =
+          s"'${params("value").toString.replace("\\", "\\\\").replace("'", 
"\\'")}'"
+        val sourceType = s""""${params("fromType").toString}""""
+        val targetType = s""""${params("toType").toString}""""
+        Some(
+          new SparkDateTimeException(
+            errorClass = "CAST_INVALID_INPUT",
+            messageParameters = Map(
+              "expression" -> expression,
+              "sourceType" -> sourceType,
+              "targetType" -> targetType,
+              "ansiConfig" -> "\"spark.sql.ansi.enabled\""),
+            context = context,
+            summary = summary))
+
       case "CastOverFlow" =>
         val fromType = getDataType(params("fromType").toString)
         val toType = getDataType(params("toType").toString)
diff --git 
a/spark/src/main/spark-4.0/org/apache/spark/sql/comet/shims/ShimSparkErrorConverter.scala
 
b/spark/src/main/spark-4.0/org/apache/spark/sql/comet/shims/ShimSparkErrorConverter.scala
index f906db140..fc13a58a4 100644
--- 
a/spark/src/main/spark-4.0/org/apache/spark/sql/comet/shims/ShimSparkErrorConverter.scala
+++ 
b/spark/src/main/spark-4.0/org/apache/spark/sql/comet/shims/ShimSparkErrorConverter.scala
@@ -182,6 +182,13 @@ trait ShimSparkErrorConverter {
           QueryExecutionErrors
             .invalidInputInCastToNumberError(targetType, str, 
context.headOption.orNull))
 
+      case "InvalidInputInCastToDatetime" =>
+        val str = UTF8String.fromString(params("value").toString)
+        val targetType = getDataType(params("toType").toString)
+        Some(
+          QueryExecutionErrors
+            .invalidInputInCastToDatetimeError(str, targetType, 
context.headOption.orNull))
+
       case "CastOverFlow" =>
         val fromType = getDataType(params("fromType").toString)
         val toType = getDataType(params("toType").toString)
diff --git a/spark/src/test/scala/org/apache/comet/CometCastSuite.scala 
b/spark/src/test/scala/org/apache/comet/CometCastSuite.scala
index 3d9acc39e..f213e90e0 100644
--- a/spark/src/test/scala/org/apache/comet/CometCastSuite.scala
+++ b/spark/src/test/scala/org/apache/comet/CometCastSuite.scala
@@ -23,7 +23,6 @@ import java.io.File
 
 import scala.collection.mutable.ListBuffer
 import scala.util.Random
-import scala.util.matching.Regex
 
 import org.apache.hadoop.fs.Path
 import org.apache.spark.sql.{CometTestBase, DataFrame, Row, SaveMode}
@@ -916,11 +915,15 @@ class CometCastSuite extends CometTestBase with 
AdaptiveSparkPlanHelper {
       "\r\n 962 \r\n",
       "\r\n 62 \r\n")
 
-    // due to limitations of NaiveDate we only support years between 262143 BC 
and 262142 AD"
-    val unsupportedYearPattern: Regex = "^\\s*[0-9]{5,}".r
+    // due to limitations of NaiveDate we only support years between 262143 BC 
and 262142 AD
+    // Filter out strings where the leading digit sequence represents a year > 
262142.
+    // All 5-digit years (10000-99999) are within bounds; only 6-digit years 
may exceed the limit.
     val fuzzDates = gen
       .generateStrings(dataSize, datePattern, 8)
-      .filterNot(str => unsupportedYearPattern.findFirstMatchIn(str).isDefined)
+      .filterNot { str =>
+        val yearStr = str.trim.takeWhile(_.isDigit)
+        yearStr.length > 6 || (yearStr.length == 6 && yearStr.toInt > 262142)
+      }
     castTest((validDates ++ invalidDates ++ fuzzDates).toDF("a"), 
DataTypes.DateType)
   }
 
@@ -951,19 +954,47 @@ class CometCastSuite extends CometTestBase with 
AdaptiveSparkPlanHelper {
     }
   }
 
-  test("cast StringType to TimestampType disabled by default") {
-    withSQLConf((SQLConf.SESSION_LOCAL_TIMEZONE.key, "UTC")) {
-      val values = Seq("2020-01-01T12:34:56.123456", "T2").toDF("a")
-      castFallbackTest(
-        values.toDF("a"),
-        DataTypes.TimestampType,
-        "Not all valid formats are supported")
+  test("cast StringType to TimestampType - UTC") {
+    withSQLConf(SQLConf.SESSION_LOCAL_TIMEZONE.key -> "UTC") {
+      val values = Seq(
+        "2020",
+        "2020-01",
+        "2020-01-01",
+        "2020-01-01T12",
+        "2020-01-01T12:34",
+        "2020-01-01T12:34:56",
+        "2020-01-01T12:34:56.123456",
+        "T2",
+        "-9?",
+        "0100",
+        "0100-01",
+        "0100-01-01",
+        "0100-01-01T12",
+        "0100-01-01T12:34",
+        "0100-01-01T12:34:56",
+        "0100-01-01T12:34:56.123456",
+        "10000",
+        "10000-01",
+        "10000-01-01",
+        "10000-01-01T12",
+        "10000-01-01T12:34",
+        "10000-01-01T12:34:56",
+        "10000-01-01T12:34:56.123456",
+        "213170",
+        "213170-06",
+        "213170-06-15",
+        "213170-06-15T12",
+        "213170-06-15T12:34",
+        "213170-06-15T12:34:56",
+        "213170-06-15T12:34:56.123456")
+      castTimestampTest(values.toDF("a"), DataTypes.TimestampType)
     }
   }
 
   ignore("cast StringType to TimestampType") {
-    // https://github.com/apache/datafusion-comet/issues/328
-    withSQLConf((CometConf.getExprAllowIncompatConfigKey(classOf[Cast]), 
"true")) {
+    // TODO: enable once all Spark timestamp formats are supported natively.
+    // Currently missing: time-only formats with colon (e.g. "T12:34", "4:4").
+    withSQLConf((SQLConf.SESSION_LOCAL_TIMEZONE.key, "UTC")) {
       val values = Seq("2020-01-01T12:34:56.123456", "T2") ++ 
gen.generateStrings(
         dataSize,
         timestampPattern,
@@ -994,6 +1025,13 @@ class CometCastSuite extends CometTestBase with 
AdaptiveSparkPlanHelper {
         "2020-01-01T12:34:56.123456",
         "T2",
         "-9?",
+        "100",
+        "100-01",
+        "100-01-01",
+        "100-01-01T12",
+        "100-01-01T12:34",
+        "100-01-01T12:34:56",
+        "100-01-01T12:34:56.123456",
         "0100",
         "0100-01",
         "0100-01-01",
@@ -1010,14 +1048,6 @@ class CometCastSuite extends CometTestBase with 
AdaptiveSparkPlanHelper {
         "10000-01-01T12:34:56.123456")
       castTimestampTest(values.toDF("a"), DataTypes.TimestampType)
     }
-
-    // test for invalid inputs
-    withSQLConf(
-      SQLConf.SESSION_LOCAL_TIMEZONE.key -> "UTC",
-      CometConf.getExprAllowIncompatConfigKey(classOf[Cast]) -> "true") {
-      val values = Seq("-9?", "1-", "0.5")
-      castTimestampTest(values.toDF("a"), DataTypes.TimestampType)
-    }
   }
 
   // CAST from BinaryType
@@ -1476,7 +1506,10 @@ class CometCastSuite extends CometTestBase with 
AdaptiveSparkPlanHelper {
     }
   }
 
-  private def castTimestampTest(input: DataFrame, toType: DataType) = {
+  private def castTimestampTest(
+      input: DataFrame,
+      toType: DataType,
+      assertNative: Boolean = false) = {
     withTempPath { dir =>
       val data = roundtripParquet(input, dir).coalesce(1)
       data.createOrReplaceTempView("t")
@@ -1484,12 +1517,47 @@ class CometCastSuite extends CometTestBase with 
AdaptiveSparkPlanHelper {
       withSQLConf((SQLConf.ANSI_ENABLED.key, "false")) {
         // cast() should return null for invalid inputs when ansi mode is 
disabled
         val df = data.withColumn("converted", col("a").cast(toType))
-        checkSparkAnswer(df)
+        if (assertNative) {
+          checkSparkAnswerAndOperator(df)
+        } else {
+          checkSparkAnswer(df)
+        }
 
         // try_cast() should always return null for invalid inputs
         val df2 = spark.sql(s"select try_cast(a as ${toType.sql}) from t")
         checkSparkAnswer(df2)
       }
+
+      // with ANSI enabled, we should produce the same exception as Spark
+      withSQLConf(SQLConf.ANSI_ENABLED.key -> "true") {
+        val df = data.withColumn("converted", col("a").cast(toType))
+        checkSparkAnswerMaybeThrows(df) match {
+          case (None, None) =>
+          // both succeeded, results already compared
+          case (None, Some(e)) =>
+            throw e
+          case (Some(e), None) =>
+            fail(s"Comet should have failed with ${e.getCause.getMessage}")
+          case (Some(sparkException), Some(cometException)) =>
+            val sparkMessage =
+              if (sparkException.getCause != null) 
sparkException.getCause.getMessage
+              else sparkException.getMessage
+            val cometMessage =
+              if (cometException.getCause != null) 
cometException.getCause.getMessage
+              else cometException.getMessage
+            if (CometSparkSessionExtensions.isSpark40Plus) {
+              assert(sparkMessage.contains("SQLSTATE"))
+              assert(
+                sparkMessage.startsWith(
+                  cometMessage.substring(0, math.min(40, 
cometMessage.length))))
+            } else {
+              assert(cometMessage == sparkMessage)
+            }
+        }
+        // try_cast()
+        val dfTryCast = spark.sql(s"select try_cast(a as ${toType.sql}) from 
t")
+        checkSparkAnswer(dfTryCast)
+      }
     }
   }
 
@@ -1575,7 +1643,9 @@ class CometCastSuite extends CometTestBase with 
AdaptiveSparkPlanHelper {
                   // context eagerly so it displays the call site at the
                   // line of code where the cast method was called, whereas 
spark grabs the context
                   // lazily and displays the call site at the line of code 
where the error is checked.
-                  assert(sparkMessage.startsWith(cometMessage.substring(0, 
40)))
+                  assert(
+                    sparkMessage.startsWith(
+                      cometMessage.substring(0, math.min(40, 
cometMessage.length))))
                 } else {
                   // for Spark 3.4 we expect to reproduce the error message 
exactly
                   assert(cometMessage == sparkMessage)


---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]

Reply via email to