This is an automated email from the ASF dual-hosted git repository.

parthc pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/datafusion-comet.git


The following commit(s) were added to refs/heads/main by this push:
     new 34bbe78f8 feat: Cast numeric (non int)  to timestamp (#3559)
34bbe78f8 is described below

commit 34bbe78f87b36e596523de45fac3c7a54fc72aae
Author: Bhargava Vadlamani <[email protected]>
AuthorDate: Fri Mar 6 13:52:56 2026 -0800

    feat: Cast numeric (non int)  to timestamp (#3559)
    
    * float_to_timestamp
    
    * non_numeric_to_timestamp
---
 native/spark-expr/Cargo.toml                       |   4 +
 .../benches/cast_non_int_numeric_timestamp.rs      | 143 ++++++++++++++
 native/spark-expr/src/conversion_funcs/boolean.rs  |  45 ++++-
 native/spark-expr/src/conversion_funcs/cast.rs     |  16 +-
 native/spark-expr/src/conversion_funcs/numeric.rs  | 212 ++++++++++++++++++++-
 .../org/apache/comet/expressions/CometCast.scala   |  31 +--
 .../scala/org/apache/comet/CometCastSuite.scala    |  79 +++++---
 .../scala/org/apache/spark/sql/CometTestBase.scala |  49 +++++
 8 files changed, 531 insertions(+), 48 deletions(-)

diff --git a/native/spark-expr/Cargo.toml b/native/spark-expr/Cargo.toml
index e7c238f7e..b014c49a2 100644
--- a/native/spark-expr/Cargo.toml
+++ b/native/spark-expr/Cargo.toml
@@ -103,3 +103,7 @@ path = "tests/spark_expr_reg.rs"
 [[bench]]
 name = "cast_from_boolean"
 harness = false
+
+[[bench]]
+name = "cast_non_int_numeric_timestamp"
+harness = false
diff --git a/native/spark-expr/benches/cast_non_int_numeric_timestamp.rs 
b/native/spark-expr/benches/cast_non_int_numeric_timestamp.rs
new file mode 100644
index 000000000..ea1a85e40
--- /dev/null
+++ b/native/spark-expr/benches/cast_non_int_numeric_timestamp.rs
@@ -0,0 +1,143 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements.  See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership.  The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License.  You may obtain a copy of the License at
+//
+//   http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied.  See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+use arrow::array::builder::{BooleanBuilder, Decimal128Builder, Float32Builder, 
Float64Builder};
+use arrow::array::RecordBatch;
+use arrow::datatypes::{DataType, Field, Schema, TimeUnit};
+use criterion::{criterion_group, criterion_main, Criterion};
+use datafusion::physical_expr::{expressions::Column, PhysicalExpr};
+use datafusion_comet_spark_expr::{Cast, EvalMode, SparkCastOptions};
+use std::sync::Arc;
+
+const BATCH_SIZE: usize = 8192;
+
+/// Benchmarks legacy-mode casts of Float32, Float64, Boolean and Decimal128(18, 6) columns to
+/// `Timestamp(Microsecond, "UTC")`, all registered in the `cast_non_int_numeric_to_timestamp`
+/// group.
+fn criterion_benchmark(c: &mut Criterion) {
+    let cast_options = SparkCastOptions::new(EvalMode::Legacy, "UTC", false);
+    let to_type = DataType::Timestamp(TimeUnit::Microsecond, Some("UTC".into()));
+
+    let mut group = c.benchmark_group("cast_non_int_numeric_to_timestamp");
+
+    // (benchmark id, input batch); every batch holds the source values in column 0 ("a").
+    let cases = [
+        ("cast_f32_to_timestamp", create_float32_batch()),
+        ("cast_f64_to_timestamp", create_float64_batch()),
+        ("cast_bool_to_timestamp", create_boolean_batch()),
+        ("cast_decimal_to_timestamp", create_decimal128_batch()),
+    ];
+
+    for (id, batch) in &cases {
+        let cast = Cast::new(
+            Arc::new(Column::new("a", 0)),
+            to_type.clone(),
+            cast_options.clone(),
+        );
+        group.bench_function(*id, |b| b.iter(|| cast.evaluate(batch).unwrap()));
+    }
+
+    group.finish();
+}
+
+/// Returns a one-column batch (`a: Float32`, nullable) of `BATCH_SIZE` rows filled with
+/// `rand::random::<f32>()`, except that every tenth row, starting with row 0, is null.
+fn create_float32_batch() -> RecordBatch {
+    let mut column = Float32Builder::with_capacity(BATCH_SIZE);
+    for row in 0..BATCH_SIZE {
+        // The generator is only called for non-null rows, exactly once per such row.
+        column.append_option((row % 10 != 0).then(rand::random::<f32>));
+    }
+    let schema = Schema::new(vec![Field::new("a", DataType::Float32, true)]);
+    RecordBatch::try_new(Arc::new(schema), vec![Arc::new(column.finish())]).unwrap()
+}
+
+/// Returns a one-column batch (`a: Float64`, nullable) of `BATCH_SIZE` rows filled with
+/// `rand::random::<f64>()`, except that every tenth row, starting with row 0, is null.
+fn create_float64_batch() -> RecordBatch {
+    let mut column = Float64Builder::with_capacity(BATCH_SIZE);
+    for row in 0..BATCH_SIZE {
+        // The generator is only called for non-null rows, exactly once per such row.
+        column.append_option((row % 10 != 0).then(rand::random::<f64>));
+    }
+    let schema = Schema::new(vec![Field::new("a", DataType::Float64, true)]);
+    RecordBatch::try_new(Arc::new(schema), vec![Arc::new(column.finish())]).unwrap()
+}
+
+/// Returns a one-column batch (`a: Boolean`, nullable) of `BATCH_SIZE` rows filled with
+/// `rand::random::<bool>()`, except that every tenth row, starting with row 0, is null.
+fn create_boolean_batch() -> RecordBatch {
+    let mut column = BooleanBuilder::with_capacity(BATCH_SIZE);
+    for row in 0..BATCH_SIZE {
+        // The generator is only called for non-null rows, exactly once per such row.
+        column.append_option((row % 10 != 0).then(rand::random::<bool>));
+    }
+    let schema = Schema::new(vec![Field::new("a", DataType::Boolean, true)]);
+    RecordBatch::try_new(Arc::new(schema), vec![Arc::new(column.finish())]).unwrap()
+}
+
+/// Returns a one-column batch (`a: Decimal128(18, 6)`, nullable) of `BATCH_SIZE` rows where
+/// row `i` holds the unscaled value `i * 1_000_000` (i.e. `i` whole units at scale 6), except
+/// that every tenth row, starting with row 0, is null.
+fn create_decimal128_batch() -> RecordBatch {
+    let mut column = Decimal128Builder::with_capacity(BATCH_SIZE);
+    for row in 0..BATCH_SIZE {
+        column.append_option((row % 10 != 0).then_some(row as i128 * 1_000_000));
+    }
+    let values = column.finish().with_precision_and_scale(18, 6).unwrap();
+    let schema = Schema::new(vec![Field::new("a", DataType::Decimal128(18, 6), true)]);
+    RecordBatch::try_new(Arc::new(schema), vec![Arc::new(values)]).unwrap()
+}
+
+/// Criterion settings for this benchmark: the library defaults.
+fn config() -> Criterion {
+    <Criterion as Default>::default()
+}
+
+// `benches` runs `criterion_benchmark` with the settings returned by `config()`.
+criterion_group! {
+    name = benches;
+    config = config();
+    targets = criterion_benchmark
+}
+criterion_main!(benches);
diff --git a/native/spark-expr/src/conversion_funcs/boolean.rs 
b/native/spark-expr/src/conversion_funcs/boolean.rs
index db288fa32..49855790b 100644
--- a/native/spark-expr/src/conversion_funcs/boolean.rs
+++ b/native/spark-expr/src/conversion_funcs/boolean.rs
@@ -16,7 +16,7 @@
 // under the License.
 
 use crate::SparkResult;
-use arrow::array::{ArrayRef, AsArray, Decimal128Array};
+use arrow::array::{Array, ArrayRef, AsArray, Decimal128Array, 
TimestampMicrosecondBuilder};
 use arrow::datatypes::DataType;
 use std::sync::Arc;
 
@@ -28,7 +28,6 @@ pub fn is_df_cast_from_bool_spark_compatible(to_type: 
&DataType) -> bool {
     )
 }
 
-// only DF incompatible boolean cast
 pub fn cast_boolean_to_decimal(
     array: &ArrayRef,
     precision: u8,
@@ -43,6 +42,25 @@ pub fn cast_boolean_to_decimal(
     Ok(Arc::new(result.with_precision_and_scale(precision, scale)?))
 }
 
+/// Casts a Boolean array to `Timestamp(Microsecond, target_tz)`.
+///
+/// `true` becomes 1 microsecond after the Unix epoch, `false` becomes the epoch itself (0) and
+/// null entries stay null. The resulting array carries `target_tz` as its timezone.
+pub(crate) fn cast_boolean_to_timestamp(
+    array_ref: &ArrayRef,
+    target_tz: &Option<Arc<str>>,
+) -> SparkResult<ArrayRef> {
+    let booleans = array_ref.as_boolean();
+    let mut timestamps = TimestampMicrosecondBuilder::with_capacity(booleans.len());
+
+    for row in 0..booleans.len() {
+        let micros = if booleans.is_valid(row) {
+            Some(i64::from(booleans.value(row)))
+        } else {
+            None
+        };
+        timestamps.append_option(micros);
+    }
+
+    let result = timestamps.finish().with_timezone_opt(target_tz.clone());
+    Ok(Arc::new(result) as ArrayRef)
+}
+
 #[cfg(test)]
 mod tests {
     use super::*;
@@ -53,6 +71,7 @@ mod tests {
         Int64Array, Int8Array, StringArray,
     };
     use arrow::datatypes::DataType::Decimal128;
+    use arrow::datatypes::TimestampMicrosecondType;
     use std::sync::Arc;
 
     fn test_input_bool_array() -> ArrayRef {
@@ -193,4 +212,26 @@ mod tests {
         assert_eq!(arr.value(1), expected_arr.value(1));
         assert!(arr.is_null(2));
     }
+
+    #[test]
+    fn test_cast_boolean_to_timestamp() {
+        let timezones: [Option<Arc<str>>; 3] = [
+            Some(Arc::from("UTC")),
+            Some(Arc::from("America/Los_Angeles")),
+            None,
+        ];
+
+        for tz in timezones.iter() {
+            let input: ArrayRef =
+                Arc::new(BooleanArray::from(vec![Some(true), Some(false), None]));
+
+            let output = cast_boolean_to_timestamp(&input, tz).unwrap();
+            let micros = output.as_primitive::<TimestampMicrosecondType>();
+
+            // true -> 1 microsecond after the epoch, false -> the epoch, null stays null.
+            assert_eq!((micros.value(0), micros.value(1)), (1, 0));
+            assert!(micros.is_null(2));
+            // The requested timezone (or its absence) is attached unchanged.
+            assert_eq!(micros.timezone(), tz.as_deref());
+        }
+    }
 }
diff --git a/native/spark-expr/src/conversion_funcs/cast.rs 
b/native/spark-expr/src/conversion_funcs/cast.rs
index ff09dbe06..a9e688814 100644
--- a/native/spark-expr/src/conversion_funcs/cast.rs
+++ b/native/spark-expr/src/conversion_funcs/cast.rs
@@ -16,14 +16,15 @@
 // under the License.
 
 use crate::conversion_funcs::boolean::{
-    cast_boolean_to_decimal, is_df_cast_from_bool_spark_compatible,
+    cast_boolean_to_decimal, cast_boolean_to_timestamp, 
is_df_cast_from_bool_spark_compatible,
 };
 use crate::conversion_funcs::numeric::{
-    cast_float32_to_decimal128, cast_float64_to_decimal128, 
cast_int_to_decimal128,
-    cast_int_to_timestamp, is_df_cast_from_decimal_spark_compatible,
-    is_df_cast_from_float_spark_compatible, 
is_df_cast_from_int_spark_compatible,
-    spark_cast_decimal_to_boolean, spark_cast_float32_to_utf8, 
spark_cast_float64_to_utf8,
-    spark_cast_int_to_int, spark_cast_nonintegral_numeric_to_integral,
+    cast_decimal_to_timestamp, cast_float32_to_decimal128, 
cast_float64_to_decimal128,
+    cast_float_to_timestamp, cast_int_to_decimal128, cast_int_to_timestamp,
+    is_df_cast_from_decimal_spark_compatible, 
is_df_cast_from_float_spark_compatible,
+    is_df_cast_from_int_spark_compatible, spark_cast_decimal_to_boolean,
+    spark_cast_float32_to_utf8, spark_cast_float64_to_utf8, 
spark_cast_int_to_int,
+    spark_cast_nonintegral_numeric_to_integral,
 };
 use crate::conversion_funcs::string::{
     cast_string_to_date, cast_string_to_decimal, cast_string_to_float, 
cast_string_to_int,
@@ -384,6 +385,9 @@ pub(crate) fn cast_array(
             cast_boolean_to_decimal(&array, *precision, *scale)
         }
         (Int8 | Int16 | Int32 | Int64, Timestamp(_, tz)) => 
cast_int_to_timestamp(&array, tz),
+        (Float32 | Float64, Timestamp(_, tz)) => 
cast_float_to_timestamp(&array, tz, eval_mode),
+        (Boolean, Timestamp(_, tz)) => cast_boolean_to_timestamp(&array, tz),
+        (Decimal128(_, scale), Timestamp(_, tz)) => 
cast_decimal_to_timestamp(&array, tz, *scale),
         _ if cast_options.is_adapting_schema
             || is_datafusion_spark_compatible(&from_type, to_type) =>
         {
diff --git a/native/spark-expr/src/conversion_funcs/numeric.rs 
b/native/spark-expr/src/conversion_funcs/numeric.rs
index d204e2871..59a65fb49 100644
--- a/native/spark-expr/src/conversion_funcs/numeric.rs
+++ b/native/spark-expr/src/conversion_funcs/numeric.rs
@@ -24,7 +24,7 @@ use arrow::array::{
     OffsetSizeTrait, PrimitiveArray, TimestampMicrosecondBuilder,
 };
 use arrow::datatypes::{
-    is_validate_decimal_precision, ArrowPrimitiveType, DataType, 
Decimal128Type, Float32Type,
+    i256, is_validate_decimal_precision, ArrowPrimitiveType, DataType, 
Decimal128Type, Float32Type,
     Float64Type, Int16Type, Int32Type, Int64Type, Int8Type,
 };
 use num::{cast::AsPrimitive, ToPrimitive, Zero};
@@ -75,6 +75,56 @@ pub(crate) fn 
is_df_cast_from_decimal_spark_compatible(to_type: &DataType) -> bo
     )
 }
 
+/// Appends the result of casting every element of `$array` (a Float32 or Float64 array, read as
+/// `$primitive_type`) to a timestamp onto the `TimestampMicrosecondBuilder` `$builder`.
+///
+/// Each value is interpreted as seconds since the Unix epoch, mirroring Spark's
+/// `doubleToTimestamp` (legacy/try) and `doubleToTimestampAnsi` (ANSI):
+/// * NaN and +/-Infinity produce NULL. ANSI mode instead raises `CastInvalidValue`
+///   (DOUBLE -> TIMESTAMP).
+/// * Finite values are multiplied by `MICROS_PER_SECOND`. If the product does not fit in an
+///   `i64`, ANSI mode raises `CastOverFlow` (DOUBLE -> BIGINT, which is how Spark reports it).
+///   The other modes convert it like the JVM's `(long)` cast, saturating at `i64::MIN` or
+///   `i64::MAX`, instead of producing NULL.
+///
+/// The error paths `return` from the enclosing function, so the macro must be expanded inside a
+/// function returning `SparkResult<_>`.
+macro_rules! cast_float_to_timestamp_impl {
+    ($array:expr, $builder:expr, $primitive_type:ty, $eval_mode:expr) => {{
+        // Renders a double the way Spark's error messages do: `NaN`, `Infinity`, `-Infinity`, or
+        // Java's `Double.toString` plus a `D` suffix (e.g. `1.0E25D`; Java always keeps a
+        // fractional digit in the mantissa, Rust's `{:E}` does not).
+        let spark_double_literal = |v: f64| -> String {
+            if v.is_nan() {
+                "NaN".to_string()
+            } else if v.is_infinite() {
+                let sign = if v.is_sign_positive() { "" } else { "-" };
+                format!("{}Infinity", sign)
+            } else {
+                let formatted = format!("{:E}", v);
+                match formatted.split_once('E') {
+                    Some((mantissa, exponent)) if !mantissa.contains('.') => {
+                        format!("{}.0E{}D", mantissa, exponent)
+                    }
+                    _ => formatted + "D",
+                }
+            }
+        };
+        let arr = $array.as_primitive::<$primitive_type>();
+        for i in 0..arr.len() {
+            if arr.is_null(i) {
+                $builder.append_null();
+                continue;
+            }
+            let val = arr.value(i) as f64;
+            if !val.is_finite() {
+                // Path 1: NaN/Infinity input - the error is reported against TIMESTAMP.
+                if $eval_mode == EvalMode::Ansi {
+                    return Err(SparkError::CastInvalidValue {
+                        value: spark_double_literal(val),
+                        from_type: "DOUBLE".to_string(),
+                        to_type: "TIMESTAMP".to_string(),
+                    });
+                }
+                $builder.append_null();
+                continue;
+            }
+            // Path 2: scale to microseconds, then range-check - the error says BIGINT.
+            // `micros` is finite or +/-Infinity here, never NaN.
+            let micros = val * MICROS_PER_SECOND as f64;
+            let fits_in_i64 =
+                micros.floor() <= i64::MAX as f64 && micros.ceil() >= i64::MIN as f64;
+            if !fits_in_i64 && $eval_mode == EvalMode::Ansi {
+                return Err(SparkError::CastOverFlow {
+                    value: spark_double_literal(micros),
+                    from_type: "DOUBLE".to_string(),
+                    to_type: "BIGINT".to_string(),
+                });
+            }
+            // Rust's float-to-int `as` saturates, just like the JVM's `(long)` conversion.
+            $builder.append_value(micros as i64);
+        }
+    }};
+}
+
 macro_rules! cast_float_to_string {
     ($from:expr, $eval_mode:expr, $type:ty, $output_type:ty, $offset_type:ty) 
=> {{
 
@@ -913,6 +963,57 @@ pub(crate) fn cast_int_to_timestamp(
     Ok(Arc::new(builder.finish().with_timezone_opt(target_tz.clone())) as 
ArrayRef)
 }
 
+/// Casts a `Decimal128` array with the given `scale` to `Timestamp(Microsecond, target_tz)`,
+/// interpreting each decimal as seconds since the Unix epoch.
+///
+/// Like Spark's BigDecimal conversion, the product `decimal * MICROS_PER_SECOND` is computed
+/// exactly (in 256-bit arithmetic). The fractional part is truncated toward zero and the result
+/// is reduced to its low-order 64 bits. No error is raised in any eval mode.
+///
+/// Negative scales (permitted by Arrow) scale the value up instead of down. The wrapping 256-bit
+/// multiplication used there still yields the exact low-order 64 bits, which is all that survives
+/// the final conversion to `i64`.
+pub(crate) fn cast_decimal_to_timestamp(
+    array_ref: &ArrayRef,
+    target_tz: &Option<Arc<str>>,
+    scale: i8,
+) -> SparkResult<ArrayRef> {
+    let arr = array_ref.as_primitive::<Decimal128Type>();
+
+    // micros = unscaled * MICROS_PER_SECOND * 10^(-scale) = (unscaled * multiplier) / divisor.
+    // Both factors are loop invariants, so they are built once instead of once per row.
+    let micros_per_second = i256::from_i128(MICROS_PER_SECOND as i128);
+    let (multiplier, divisor) = if scale >= 0 {
+        (micros_per_second, i256::from_i128(10_i128.pow(scale as u32)))
+    } else {
+        let ten = i256::from_i128(10);
+        let multiplier =
+            (0..scale.unsigned_abs()).fold(micros_per_second, |acc, _| acc.wrapping_mul(ten));
+        (multiplier, i256::ONE)
+    };
+
+    let mut builder = TimestampMicrosecondBuilder::with_capacity(arr.len());
+    for value in arr.iter() {
+        match value {
+            Some(unscaled) => {
+                // For scale >= 0 the product is below 2^147 in magnitude, so it is exact. The
+                // i256 division then truncates toward zero, discarding sub-microsecond digits.
+                let micros = i256::from_i128(unscaled).wrapping_mul(multiplier) / divisor;
+                // Spark's big decimal truncates to a long value without raising an error (in all
+                // eval modes): keep only the low-order 64 bits.
+                builder.append_value(micros.as_i128() as i64);
+            }
+            None => builder.append_null(),
+        }
+    }
+
+    Ok(Arc::new(builder.finish().with_timezone_opt(target_tz.clone())) as ArrayRef)
+}
+
+/// Casts a `Float32` or `Float64` array to `Timestamp(Microsecond, target_tz)`, interpreting
+/// each value as seconds since the Unix epoch.
+///
+/// Per-row semantics (NULL and error handling for NaN, infinities and out-of-range values under
+/// `eval_mode`) are implemented by `cast_float_to_timestamp_impl!`, which may return early with
+/// an error. Any other input type yields `SparkError::Internal`.
+pub(crate) fn cast_float_to_timestamp(
+    array_ref: &ArrayRef,
+    target_tz: &Option<Arc<str>>,
+    eval_mode: EvalMode,
+) -> SparkResult<ArrayRef> {
+    let mut builder = TimestampMicrosecondBuilder::with_capacity(array_ref.len());
+
+    let from_type = array_ref.data_type();
+    if matches!(from_type, DataType::Float32) {
+        cast_float_to_timestamp_impl!(array_ref, builder, Float32Type, eval_mode);
+    } else if matches!(from_type, DataType::Float64) {
+        cast_float_to_timestamp_impl!(array_ref, builder, Float64Type, eval_mode);
+    } else {
+        return Err(SparkError::Internal(format!(
+            "Unsupported type for cast_float_to_timestamp: {:?}",
+            from_type
+        )));
+    }
+
+    let timestamps = builder.finish().with_timezone_opt(target_tz.clone());
+    Ok(Arc::new(timestamps) as ArrayRef)
+}
+
 #[cfg(test)]
 mod tests {
     use super::*;
@@ -1100,4 +1201,113 @@ mod tests {
         assert!(casted.is_null(8));
         assert!(casted.is_null(9));
     }
+
+    #[test]
+    fn test_cast_decimal_to_timestamp() {
+        let timezones: [Option<Arc<str>>; 3] = [
+            Some(Arc::from("UTC")),
+            Some(Arc::from("America/Los_Angeles")),
+            None,
+        ];
+
+        for tz in timezones.iter() {
+            // Scale 6: the unscaled value already is the number of microseconds.
+            let scale6: ArrayRef = Arc::new(
+                Decimal128Array::from(vec![
+                    Some(0_i128),
+                    Some(1_000_000_i128),
+                    Some(-1_000_000_i128),
+                    Some(1_500_000_i128),
+                    Some(123_456_789_i128),
+                    None,
+                ])
+                .with_precision_and_scale(18, 6)
+                .unwrap(),
+            );
+            let casted6 = cast_decimal_to_timestamp(&scale6, tz, 6).unwrap();
+            let micros6 = casted6.as_primitive::<TimestampMicrosecondType>();
+            let expected6: [i64; 5] = [0, 1_000_000, -1_000_000, 1_500_000, 123_456_789];
+            for (row, want) in expected6.iter().enumerate() {
+                assert_eq!(micros6.value(row), *want);
+            }
+            assert!(micros6.is_null(expected6.len()));
+            assert_eq!(micros6.timezone(), tz.as_deref());
+
+            // Scale 2: 1.00, 1.50 and -2.50 seconds.
+            let scale2: ArrayRef = Arc::new(
+                Decimal128Array::from(vec![Some(100_i128), Some(150_i128), Some(-250_i128)])
+                    .with_precision_and_scale(10, 2)
+                    .unwrap(),
+            );
+            let casted2 = cast_decimal_to_timestamp(&scale2, tz, 2).unwrap();
+            let micros2 = casted2.as_primitive::<TimestampMicrosecondType>();
+            let expected2: [i64; 3] = [1_000_000, 1_500_000, -2_500_000];
+            for (row, want) in expected2.iter().enumerate() {
+                assert_eq!(micros2.value(row), *want);
+            }
+        }
+    }
+
+    #[test]
+    fn test_cast_float_to_timestamp() {
+        let utc: Option<Arc<str>> = Some(Arc::from("UTC"));
+        let timezones: [Option<Arc<str>>; 3] =
+            [utc.clone(), Some(Arc::from("America/Los_Angeles")), None];
+
+        for tz in timezones.iter() {
+            for &eval_mode in [EvalMode::Legacy, EvalMode::Ansi, EvalMode::Try].iter() {
+                // Float64: seconds -> microseconds, including fractional and sub-second values.
+                let doubles: ArrayRef = Arc::new(Float64Array::from(vec![
+                    Some(0.0),
+                    Some(1.0),
+                    Some(-1.0),
+                    Some(1.5),
+                    Some(0.000001),
+                    None,
+                ]));
+                let casted_f64 = cast_float_to_timestamp(&doubles, tz, eval_mode).unwrap();
+                let micros_f64 = casted_f64.as_primitive::<TimestampMicrosecondType>();
+                let expected_f64: [i64; 5] = [0, 1_000_000, -1_000_000, 1_500_000, 1];
+                for (row, want) in expected_f64.iter().enumerate() {
+                    assert_eq!(micros_f64.value(row), *want);
+                }
+                assert!(micros_f64.is_null(expected_f64.len()));
+                assert_eq!(micros_f64.timezone(), tz.as_deref());
+
+                // Float32 goes through the same conversion after widening to f64.
+                let floats: ArrayRef = Arc::new(Float32Array::from(vec![
+                    Some(0.0_f32),
+                    Some(1.0_f32),
+                    Some(-1.0_f32),
+                    None,
+                ]));
+                let casted_f32 = cast_float_to_timestamp(&floats, tz, eval_mode).unwrap();
+                let micros_f32 = casted_f32.as_primitive::<TimestampMicrosecondType>();
+                let expected_f32: [i64; 3] = [0, 1_000_000, -1_000_000];
+                for (row, want) in expected_f32.iter().enumerate() {
+                    assert_eq!(micros_f32.value(row), *want);
+                }
+                assert!(micros_f32.is_null(expected_f32.len()));
+            }
+        }
+
+        // ANSI mode rejects NaN and infinity instead of producing NULL.
+        for bad in [f64::NAN, f64::INFINITY].iter() {
+            let input: ArrayRef = Arc::new(Float64Array::from(vec![Some(*bad)]));
+            assert!(cast_float_to_timestamp(&input, &utc, EvalMode::Ansi).is_err());
+        }
+    }
 }
diff --git a/spark/src/main/scala/org/apache/comet/expressions/CometCast.scala 
b/spark/src/main/scala/org/apache/comet/expressions/CometCast.scala
index 15dfcb2d7..95d536690 100644
--- a/spark/src/main/scala/org/apache/comet/expressions/CometCast.scala
+++ b/spark/src/main/scala/org/apache/comet/expressions/CometCast.scala
@@ -21,7 +21,7 @@ package org.apache.comet.expressions
 
 import org.apache.spark.sql.catalyst.expressions.{Attribute, Cast, Expression, 
Literal}
 import org.apache.spark.sql.internal.SQLConf
-import org.apache.spark.sql.types.{ArrayType, DataType, DataTypes, 
DecimalType, NullType, StructType}
+import org.apache.spark.sql.types.{ArrayType, DataType, DataTypes, 
DecimalType, NullType, StructType, TimestampType}
 
 import org.apache.comet.CometConf
 import org.apache.comet.CometSparkSessionExtensions.withInfo
@@ -63,16 +63,17 @@ object CometCast extends CometExpressionSerde[Cast] with 
CometExprShim {
       cast: Cast,
       inputs: Seq[Attribute],
       binding: Boolean): Option[ExprOuterClass.Expr] = {
+    val cometEvalMode = evalMode(cast)
     cast.child match {
       case _: Literal =>
         exprToProtoInternal(Literal.create(cast.eval(), cast.dataType), 
inputs, binding)
       case _ =>
-        if (isAlwaysCastToNull(cast.child.dataType, cast.dataType, 
evalMode(cast))) {
+        if (isAlwaysCastToNull(cast.child.dataType, cast.dataType, 
cometEvalMode)) {
           exprToProtoInternal(Literal.create(null, cast.dataType), inputs, 
binding)
         } else {
           val childExpr = exprToProtoInternal(cast.child, inputs, binding)
           if (childExpr.isDefined) {
-            castToProto(cast, cast.timeZoneId, cast.dataType, childExpr.get, 
evalMode(cast))
+            castToProto(cast, cast.timeZoneId, cast.dataType, childExpr.get, 
cometEvalMode)
           } else {
             withInfo(cast, cast.child)
             None
@@ -165,7 +166,7 @@ object CometCast extends CometExpressionSerde[Cast] with 
CometExprShim {
       case (_: DecimalType, _) =>
         canCastFromDecimal(toType)
       case (DataTypes.BooleanType, _) =>
-        canCastFromBoolean(toType)
+        canCastFromBoolean(toType, evalMode)
       case (DataTypes.ByteType, _) =>
         canCastFromByte(toType, evalMode)
       case (DataTypes.ShortType, _) =>
@@ -282,12 +283,15 @@ object CometCast extends CometExpressionSerde[Cast] with 
CometExprShim {
     }
   }
 
-  private def canCastFromBoolean(toType: DataType): SupportLevel = toType 
match {
-    case DataTypes.ByteType | DataTypes.ShortType | DataTypes.IntegerType | 
DataTypes.LongType |
-        DataTypes.FloatType | DataTypes.DoubleType | _: DecimalType =>
-      Compatible()
-    case _ => unsupported(DataTypes.BooleanType, toType)
-  }
+  /**
+   * Support level for casting a BooleanType value to `toType` under `evalMode`. Timestamp targets
+   * are only supported in legacy mode; every other target not listed below is unsupported.
+   */
+  private def canCastFromBoolean(toType: DataType, evalMode: CometEvalMode.Value): SupportLevel =
+    toType match {
+      case _: TimestampType =>
+        // Spark only accepts Boolean -> Timestamp in legacy mode (ANSI and try_cast reject it).
+        if (evalMode == CometEvalMode.LEGACY) Compatible()
+        else unsupported(DataTypes.BooleanType, toType)
+      case DataTypes.ByteType | DataTypes.ShortType | DataTypes.IntegerType | DataTypes.LongType |
+          DataTypes.FloatType | DataTypes.DoubleType | _: DecimalType =>
+        Compatible()
+      case _ =>
+        unsupported(DataTypes.BooleanType, toType)
+    }
 
   private def canCastFromByte(toType: DataType, evalMode: 
CometEvalMode.Value): SupportLevel =
     toType match {
@@ -357,7 +361,7 @@ object CometCast extends CometExpressionSerde[Cast] with 
CometExprShim {
 
   private def canCastFromFloat(toType: DataType): SupportLevel = toType match {
     case DataTypes.BooleanType | DataTypes.DoubleType | DataTypes.ByteType | 
DataTypes.ShortType |
-        DataTypes.IntegerType | DataTypes.LongType =>
+        DataTypes.IntegerType | DataTypes.LongType | DataTypes.TimestampType =>
       Compatible()
     case _: DecimalType =>
       // https://github.com/apache/datafusion-comet/issues/1371
@@ -368,7 +372,7 @@ object CometCast extends CometExpressionSerde[Cast] with 
CometExprShim {
 
   private def canCastFromDouble(toType: DataType): SupportLevel = toType match 
{
     case DataTypes.BooleanType | DataTypes.FloatType | DataTypes.ByteType | 
DataTypes.ShortType |
-        DataTypes.IntegerType | DataTypes.LongType =>
+        DataTypes.IntegerType | DataTypes.LongType | DataTypes.TimestampType =>
       Compatible()
     case _: DecimalType =>
       // https://github.com/apache/datafusion-comet/issues/1371
@@ -378,7 +382,8 @@ object CometCast extends CometExpressionSerde[Cast] with 
CometExprShim {
 
  private def canCastFromDecimal(toType: DataType): SupportLevel = toType 
match {
+    // TimestampType is evaluated natively by cast_decimal_to_timestamp, which interprets the
+    // decimal value as seconds since the Unix epoch and does not raise errors in any eval mode.
     case DataTypes.FloatType | DataTypes.DoubleType | DataTypes.ByteType | 
DataTypes.ShortType |
-        DataTypes.IntegerType | DataTypes.LongType | DataTypes.BooleanType =>
+        DataTypes.IntegerType | DataTypes.LongType | DataTypes.BooleanType |
+        DataTypes.TimestampType =>
       Compatible()
     case _ => Unsupported(Some(s"Cast from DecimalType to $toType is not 
supported"))
   }
diff --git a/spark/src/test/scala/org/apache/comet/CometCastSuite.scala 
b/spark/src/test/scala/org/apache/comet/CometCastSuite.scala
index 72c2390d7..48242a978 100644
--- a/spark/src/test/scala/org/apache/comet/CometCastSuite.scala
+++ b/spark/src/test/scala/org/apache/comet/CometCastSuite.scala
@@ -167,9 +167,9 @@ class CometCastSuite extends CometTestBase with 
AdaptiveSparkPlanHelper {
     castTest(generateBools(), DataTypes.StringType)
   }
 
-  ignore("cast BooleanType to TimestampType") {
-    // Arrow error: Cast error: Casting from Boolean to Timestamp(Microsecond, 
Some("UTC")) not supported
-    castTest(generateBools(), DataTypes.TimestampType)
+  test("cast BooleanType to TimestampType") {
+    // Spark rejects Boolean -> Timestamp in ANSI and try modes, so only legacy mode is checked.
+    castTest(generateBools(), DataTypes.TimestampType, testTry = false, testAnsi = false)
   }
 
   // CAST from ByteType
@@ -504,9 +504,13 @@ class CometCastSuite extends CometTestBase with 
AdaptiveSparkPlanHelper {
     castTest(withNulls(values).toDF("a"), DataTypes.StringType)
   }
 
-  ignore("cast FloatType to TimestampType") {
-    // java.lang.ArithmeticException: long overflow
-    castTest(generateFloats(), DataTypes.TimestampType)
+  test("cast FloatType to TimestampType") {
+    for (tz <- compatibleTimezones) {
+      withSQLConf(SQLConf.SESSION_LOCAL_TIMEZONE.key -> tz) {
+        // useDataFrameDiff compares results without collect(), which fails on the extreme
+        // timestamp values these inputs produce.
+        castTest(generateFloats(), DataTypes.TimestampType, useDataFrameDiff = true)
+      }
+    }
   }
 
   // CAST from DoubleType
@@ -560,9 +564,13 @@ class CometCastSuite extends CometTestBase with 
AdaptiveSparkPlanHelper {
     castTest(withNulls(values).toDF("a"), DataTypes.StringType)
   }
 
-  ignore("cast DoubleType to TimestampType") {
-    // java.lang.ArithmeticException: long overflow
-    castTest(generateDoubles(), DataTypes.TimestampType)
+  test("cast DoubleType to TimestampType") {
+    for (tz <- compatibleTimezones) {
+      withSQLConf(SQLConf.SESSION_LOCAL_TIMEZONE.key -> tz) {
+        // useDataFrameDiff compares results without collect(), which fails on the extreme
+        // timestamp values these inputs produce.
+        castTest(generateDoubles(), DataTypes.TimestampType, useDataFrameDiff = true)
+      }
+    }
   }
 
   // CAST from DecimalType(10,2)
@@ -627,11 +635,14 @@ class CometCastSuite extends CometTestBase with 
AdaptiveSparkPlanHelper {
     castTest(generateDecimalsPrecision10Scale2(), DataTypes.StringType)
   }
 
-  ignore("cast DecimalType(10,2) to TimestampType") {
-    // input: -123456.789000000000000000, expected: 1969-12-30 05:42:23.211, 
actual: 1969-12-31 15:59:59.876544
+  test("cast DecimalType(10,2) to TimestampType") {
+    // Uses castTest's defaults, so legacy, ANSI and try_cast evaluation are all compared.
     castTest(generateDecimalsPrecision10Scale2(), DataTypes.TimestampType)
   }
 
+  test("cast DecimalType(38,18) to TimestampType") {
+    // Named after the generator's type: the input is DecimalType(38,18), not (38,10).
+    castTest(generateDecimalsPrecision38Scale18(), DataTypes.TimestampType)
+  }
+
   // CAST from StringType
 
   test("cast StringType to BooleanType") {
@@ -1466,7 +1477,8 @@ class CometCastSuite extends CometTestBase with 
AdaptiveSparkPlanHelper {
       toType: DataType,
       hasIncompatibleType: Boolean = false,
       testAnsi: Boolean = true,
-      testTry: Boolean = true): Unit = {
+      testTry: Boolean = true,
+      useDataFrameDiff: Boolean = false): Unit = {
 
     withTempPath { dir =>
       val data = roundtripParquet(input, dir).coalesce(1)
@@ -1474,22 +1486,29 @@ class CometCastSuite extends CometTestBase with 
AdaptiveSparkPlanHelper {
       withSQLConf((SQLConf.ANSI_ENABLED.key, "false")) {
         // cast() should return null for invalid inputs when ansi mode is 
disabled
         val df = data.select(col("a"), col("a").cast(toType)).orderBy(col("a"))
-        if (hasIncompatibleType) {
-          checkSparkAnswer(df)
+        if (useDataFrameDiff) {
+          assertDataFrameEqualsWithExceptions(df, assertCometNative = 
!hasIncompatibleType)
         } else {
-          checkSparkAnswerAndOperator(df)
+          if (hasIncompatibleType) {
+            checkSparkAnswer(df)
+          } else {
+            checkSparkAnswerAndOperator(df)
+          }
         }
 
         if (testTry) {
           data.createOrReplaceTempView("t")
-//          try_cast() should always return null for invalid inputs
-//          not using spark DSL since it `try_cast` is only available from 
Spark 4x
-          val df2 =
-            spark.sql(s"select a, try_cast(a as ${toType.sql}) from t order by 
a")
+          // try_cast() should always return null for invalid inputs
+          // not using spark DSL since it `try_cast` is only available from 
Spark 4x
+          val df2 = spark.sql(s"select a, try_cast(a as ${toType.sql}) from t 
order by a")
           if (hasIncompatibleType) {
             checkSparkAnswer(df2)
           } else {
-            checkSparkAnswerAndOperator(df2)
+            if (useDataFrameDiff) {
+              assertDataFrameEqualsWithExceptions(df2, assertCometNative = 
!hasIncompatibleType)
+            } else {
+              checkSparkAnswerAndOperator(df2)
+            }
           }
         }
       }
@@ -1502,7 +1521,12 @@ class CometCastSuite extends CometTestBase with 
AdaptiveSparkPlanHelper {
 
           // cast() should throw exception on invalid inputs when ansi mode is 
enabled
           val df = data.withColumn("converted", col("a").cast(toType))
-          checkSparkAnswerMaybeThrows(df) match {
+          val res = if (useDataFrameDiff) {
+            assertDataFrameEqualsWithExceptions(df, assertCometNative = 
!hasIncompatibleType)
+          } else {
+            checkSparkAnswerMaybeThrows(df)
+          }
+          res match {
             case (None, None) =>
             // neither system threw an exception
             case (None, Some(e)) =>
@@ -1546,12 +1570,15 @@ class CometCastSuite extends CometTestBase with 
AdaptiveSparkPlanHelper {
                 }
               }
           }
+        }
 
-          // try_cast() should always return null for invalid inputs
-          if (testTry) {
-            data.createOrReplaceTempView("t")
-            val df2 =
-              spark.sql(s"select a, try_cast(a as ${toType.sql}) from t order 
by a")
+        // try_cast() should always return null for invalid inputs
+        if (testTry) {
+          data.createOrReplaceTempView("t")
+          val df2 = spark.sql(s"select a, try_cast(a as ${toType.sql}) from t 
order by a")
+          if (useDataFrameDiff) {
+            assertDataFrameEqualsWithExceptions(df2, assertCometNative = 
!hasIncompatibleType)
+          } else {
             if (hasIncompatibleType) {
               checkSparkAnswer(df2)
             } else {
diff --git a/spark/src/test/scala/org/apache/spark/sql/CometTestBase.scala 
b/spark/src/test/scala/org/apache/spark/sql/CometTestBase.scala
index 41080ed9e..f831d53bf 100644
--- a/spark/src/test/scala/org/apache/spark/sql/CometTestBase.scala
+++ b/spark/src/test/scala/org/apache/spark/sql/CometTestBase.scala
@@ -1276,4 +1276,53 @@ abstract class CometTestBase
     !usingLegacyNativeCometScan(conf) &&
     CometConf.COMET_PARQUET_UNSIGNED_SMALL_INT_CHECK.get(conf)
   }
+
+  /**
+   * Runs `df` once with Comet disabled and once with Comet enabled and compares the results
+   * without converting rows to external JVM objects. That conversion is what makes `collect()`
+   * fail on extreme timestamp values, and `Dataset.foreach` performs it as well.
+   *
+   * Each side is materialized with an eager `localCheckpoint()` while its own configuration is
+   * active. This matters for the Spark side: a DataFrame that is merely created inside
+   * `withSQLConf` is only analyzed there. A later `exceptAll` would otherwise plan it with Comet
+   * enabled. The two checkpointed results are then compared with `exceptAll()` in both
+   * directions.
+   *
+   * @param df
+   *   the query to evaluate; each side is re-planned from its logical plan
+   * @param assertCometNative
+   *   when true, also run `checkCometOperators` on the executed Comet plan
+   * @return
+   *   `(None, None)` when both sides succeed (a result mismatch fails the test); otherwise the
+   *   failures thrown by Spark and by Comet respectively (`None` for a side that succeeded)
+   */
+  protected def assertDataFrameEqualsWithExceptions(
+      df: => DataFrame,
+      assertCometNative: Boolean = true): (Option[Throwable], Option[Throwable]) = {
+
+    var expected: Try[DataFrame] = null
+    withSQLConf(CometConf.COMET_ENABLED.key -> "false") {
+      // Eager checkpoint: planning and execution both happen while Comet is disabled.
+      expected = Try(datasetOfRows(spark, df.logicalPlan).localCheckpoint())
+    }
+    val dfComet = datasetOfRows(spark, df.logicalPlan)
+    val actual = Try(dfComet.localCheckpoint())
+
+    (expected, actual) match {
+      case (Success(sparkResult), Success(cometResult)) =>
+        // Compare schemas
+        assert(
+          sparkResult.schema == cometResult.schema,
+          s"Schema mismatch:\nSpark: ${sparkResult.schema}\nComet: ${cometResult.schema}")
+
+        // Both inputs are checkpointed, so this compares the two materialized results.
+        val diffCount1 = sparkResult.exceptAll(cometResult).count()
+        val diffCount2 = cometResult.exceptAll(sparkResult).count()
+
+        if (diffCount1 > 0 || diffCount2 > 0) {
+          fail(
+            "Results do not match. " +
+              s"Rows in Spark but not Comet: $diffCount1. " +
+              s"Rows in Comet but not Spark: $diffCount2.")
+        }
+
+        if (assertCometNative) {
+          checkCometOperators(stripAQEPlan(dfComet.queryExecution.executedPlan))
+        }
+
+        (None, None)
+      case _ =>
+        (expected.failed.toOption, actual.failed.toOption)
+    }
+  }
 }


---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]


Reply via email to