This is an automated email from the ASF dual-hosted git repository.

agrove pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/datafusion-comet.git


The following commit(s) were added to refs/heads/main by this push:
     new 219859bba feat: Support int to timestamp casts (#3541)
219859bba is described below

commit 219859bbae864c53d38fdb23b711284165fcc54b
Author: B Vadlamani <[email protected]>
AuthorDate: Thu Feb 19 07:41:01 2026 -0800

    feat: Support int to timestamp casts (#3541)
---
 native/spark-expr/Cargo.toml                       |   4 +
 native/spark-expr/benches/cast_int_to_timestamp.rs | 131 +++++++++++++++++++++
 native/spark-expr/src/conversion_funcs/cast.rs     | 131 +++++++++++++++++++++
 .../org/apache/comet/expressions/CometCast.scala   |   8 ++
 .../scala/org/apache/comet/CometCastSuite.scala    |  74 ++++++++----
 5 files changed, 328 insertions(+), 20 deletions(-)
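
For context: Spark interprets an integral value cast to TIMESTAMP as seconds
since the Unix epoch. Below is a minimal sketch of the new code path through
the public Cast expression (mirroring the benchmark added in this commit; the
column name "a" and the UTC timezone are illustrative):

    use arrow::array::{Int32Array, RecordBatch};
    use arrow::datatypes::{DataType, Field, Schema, TimeUnit};
    use datafusion::physical_expr::{expressions::Column, PhysicalExpr};
    use datafusion_comet_spark_expr::{Cast, EvalMode, SparkCastOptions};
    use std::sync::Arc;

    fn main() {
        let schema = Arc::new(Schema::new(vec![Field::new("a", DataType::Int32, true)]));
        let batch =
            RecordBatch::try_new(schema, vec![Arc::new(Int32Array::from(vec![1704067200]))])
                .unwrap();
        let cast = Cast::new(
            Arc::new(Column::new("a", 0)),
            DataType::Timestamp(TimeUnit::Microsecond, Some("UTC".into())),
            SparkCastOptions::new(EvalMode::Legacy, "UTC", false),
        );
        // 1704067200 s -> 1_704_067_200_000_000 us, i.e. 2024-01-01T00:00:00Z
        let result = cast.evaluate(&batch).unwrap();
        println!("{result:?}");
    }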

diff --git a/native/spark-expr/Cargo.toml b/native/spark-expr/Cargo.toml
index fd0a211b2..63e1c0476 100644
--- a/native/spark-expr/Cargo.toml
+++ b/native/spark-expr/Cargo.toml
@@ -92,6 +92,10 @@ harness = false
 name = "to_csv"
 harness = false
 
+[[bench]]
+name = "cast_int_to_timestamp"
+harness = false
+
 [[test]]
 name = "test_udf_registration"
 path = "tests/spark_expr_reg.rs"
diff --git a/native/spark-expr/benches/cast_int_to_timestamp.rs b/native/spark-expr/benches/cast_int_to_timestamp.rs
new file mode 100644
index 000000000..20143d2b0
--- /dev/null
+++ b/native/spark-expr/benches/cast_int_to_timestamp.rs
@@ -0,0 +1,131 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements.  See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership.  The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License.  You may obtain a copy of the License at
+//
+//   http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied.  See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+use arrow::array::builder::{Int16Builder, Int32Builder, Int64Builder, Int8Builder};
+use arrow::array::RecordBatch;
+use arrow::datatypes::{DataType, Field, Schema, TimeUnit};
+use criterion::{criterion_group, criterion_main, Criterion};
+use datafusion::physical_expr::{expressions::Column, PhysicalExpr};
+use datafusion_comet_spark_expr::{Cast, EvalMode, SparkCastOptions};
+use std::sync::Arc;
+
+const BATCH_SIZE: usize = 8192;
+
+fn criterion_benchmark(c: &mut Criterion) {
+    // Test with UTC timezone
+    let spark_cast_options = SparkCastOptions::new(EvalMode::Legacy, "UTC", false);
+    let timestamp_type = DataType::Timestamp(TimeUnit::Microsecond, Some("UTC".into()));
+
+    let mut group = c.benchmark_group("cast_int_to_timestamp");
+
+    // Int8 -> Timestamp
+    let batch_i8 = create_int8_batch();
+    let expr_i8 = Arc::new(Column::new("a", 0));
+    let cast_i8_to_ts = Cast::new(expr_i8, timestamp_type.clone(), spark_cast_options.clone());
+    group.bench_function("cast_i8_to_timestamp", |b| {
+        b.iter(|| cast_i8_to_ts.evaluate(&batch_i8).unwrap());
+    });
+
+    // Int16 -> Timestamp
+    let batch_i16 = create_int16_batch();
+    let expr_i16 = Arc::new(Column::new("a", 0));
+    let cast_i16_to_ts = Cast::new(expr_i16, timestamp_type.clone(), spark_cast_options.clone());
+    group.bench_function("cast_i16_to_timestamp", |b| {
+        b.iter(|| cast_i16_to_ts.evaluate(&batch_i16).unwrap());
+    });
+
+    // Int32 -> Timestamp
+    let batch_i32 = create_int32_batch();
+    let expr_i32 = Arc::new(Column::new("a", 0));
+    let cast_i32_to_ts = Cast::new(expr_i32, timestamp_type.clone(), spark_cast_options.clone());
+    group.bench_function("cast_i32_to_timestamp", |b| {
+        b.iter(|| cast_i32_to_ts.evaluate(&batch_i32).unwrap());
+    });
+
+    // Int64 -> Timestamp
+    let batch_i64 = create_int64_batch();
+    let expr_i64 = Arc::new(Column::new("a", 0));
+    let cast_i64_to_ts = Cast::new(expr_i64, timestamp_type.clone(), spark_cast_options.clone());
+    group.bench_function("cast_i64_to_timestamp", |b| {
+        b.iter(|| cast_i64_to_ts.evaluate(&batch_i64).unwrap());
+    });
+
+    group.finish();
+}
+
+fn create_int8_batch() -> RecordBatch {
+    let schema = Arc::new(Schema::new(vec![Field::new("a", DataType::Int8, true)]));
+    let mut b = Int8Builder::with_capacity(BATCH_SIZE);
+    for i in 0..BATCH_SIZE {
+        if i % 10 == 0 {
+            b.append_null();
+        } else {
+            b.append_value(rand::random::<i8>());
+        }
+    }
+    RecordBatch::try_new(schema, vec![Arc::new(b.finish())]).unwrap()
+}
+
+fn create_int16_batch() -> RecordBatch {
+    let schema = Arc::new(Schema::new(vec![Field::new("a", DataType::Int16, true)]));
+    let mut b = Int16Builder::with_capacity(BATCH_SIZE);
+    for i in 0..BATCH_SIZE {
+        if i % 10 == 0 {
+            b.append_null();
+        } else {
+            b.append_value(rand::random::<i16>());
+        }
+    }
+    RecordBatch::try_new(schema, vec![Arc::new(b.finish())]).unwrap()
+}
+
+fn create_int32_batch() -> RecordBatch {
+    let schema = Arc::new(Schema::new(vec![Field::new("a", DataType::Int32, true)]));
+    let mut b = Int32Builder::with_capacity(BATCH_SIZE);
+    for i in 0..BATCH_SIZE {
+        if i % 10 == 0 {
+            b.append_null();
+        } else {
+            b.append_value(rand::random::<i32>());
+        }
+    }
+    RecordBatch::try_new(schema, vec![Arc::new(b.finish())]).unwrap()
+}
+
+fn create_int64_batch() -> RecordBatch {
+    let schema = Arc::new(Schema::new(vec![Field::new("a", DataType::Int64, true)]));
+    let mut b = Int64Builder::with_capacity(BATCH_SIZE);
+    for i in 0..BATCH_SIZE {
+        if i % 10 == 0 {
+            b.append_null();
+        } else {
+            b.append_value(rand::random::<i64>());
+        }
+    }
+    RecordBatch::try_new(schema, vec![Arc::new(b.finish())]).unwrap()
+}
+
+fn config() -> Criterion {
+    Criterion::default()
+}
+
+criterion_group! {
+    name = benches;
+    config = config();
+    targets = criterion_benchmark
+}
+criterion_main!(benches);
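
To run just this benchmark, the standard Criterion-via-Cargo invocation should
work (assuming you are in the native/spark-expr crate directory):

    cargo bench --bench cast_int_to_timestamp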
diff --git a/native/spark-expr/src/conversion_funcs/cast.rs b/native/spark-expr/src/conversion_funcs/cast.rs
index 2809104f2..f5ab83b8a 100644
--- a/native/spark-expr/src/conversion_funcs/cast.rs
+++ b/native/spark-expr/src/conversion_funcs/cast.rs
@@ -613,6 +613,23 @@ macro_rules! cast_decimal_to_int32_up {
     }};
 }
 
+macro_rules! cast_int_to_timestamp_impl {
+    ($array:expr, $builder:expr, $primitive_type:ty) => {{
+        let arr = $array.as_primitive::<$primitive_type>();
+        for i in 0..arr.len() {
+            if arr.is_null(i) {
+                $builder.append_null();
+            } else {
+                // saturating_mul clamps to i64::MIN/MAX on overflow instead of panicking,
+                // which could otherwise occur when converting extreme values (e.g., Long.MIN_VALUE);
+                // this matches Spark behavior irrespective of EvalMode.
+                let micros = (arr.value(i) as i64).saturating_mul(MICROS_PER_SECOND);
+                $builder.append_value(micros);
+            }
+        }
+    }};
+}
+
 // copied from arrow::dataTypes::Decimal128Type since Decimal128Type::format_decimal can't be called directly
 fn format_decimal_str(value_str: &str, precision: usize, scale: i8) -> String {
     let (sign, rest) = match value_str.strip_prefix('-') {
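
As a quick illustration of the saturation noted in the macro above (a
standalone sketch; MICROS_PER_SECOND is assumed to be the crate's 1_000_000
constant):

    fn main() {
        const MICROS_PER_SECOND: i64 = 1_000_000;
        // Typical values multiply exactly...
        assert_eq!(2i64.saturating_mul(MICROS_PER_SECOND), 2_000_000);
        // ...while extreme values clamp to the i64 bounds rather than overflowing,
        // which is what the Int64 tests below expect for i64::MAX / i64::MIN inputs.
        assert_eq!(i64::MAX.saturating_mul(MICROS_PER_SECOND), i64::MAX);
        assert_eq!(i64::MIN.saturating_mul(MICROS_PER_SECOND), i64::MIN);
    }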
@@ -915,6 +932,7 @@ fn cast_array(
         (Boolean, Decimal128(precision, scale)) => {
             cast_boolean_to_decimal(&array, *precision, *scale)
         }
+        (Int8 | Int16 | Int32 | Int64, Timestamp(_, tz)) => cast_int_to_timestamp(&array, tz),
         _ if cast_options.is_adapting_schema
             || is_datafusion_spark_compatible(from_type, to_type) =>
         {
@@ -933,6 +951,29 @@ fn cast_array(
     Ok(spark_cast_postprocess(cast_result?, from_type, to_type))
 }
 
+fn cast_int_to_timestamp(
+    array_ref: &ArrayRef,
+    target_tz: &Option<Arc<str>>,
+) -> SparkResult<ArrayRef> {
+    // Input is seconds since the epoch; multiply by MICROS_PER_SECOND to get microseconds.
+    let mut builder = TimestampMicrosecondBuilder::with_capacity(array_ref.len());
+
+    match array_ref.data_type() {
+        DataType::Int8 => cast_int_to_timestamp_impl!(array_ref, builder, Int8Type),
+        DataType::Int16 => cast_int_to_timestamp_impl!(array_ref, builder, Int16Type),
+        DataType::Int32 => cast_int_to_timestamp_impl!(array_ref, builder, Int32Type),
+        DataType::Int64 => cast_int_to_timestamp_impl!(array_ref, builder, Int64Type),
+        dt => {
+            return Err(SparkError::Internal(format!(
+                "Unsupported type for cast_int_to_timestamp: {:?}",
+                dt
+            )))
+        }
+    }
+
+    Ok(Arc::new(builder.finish().with_timezone_opt(target_tz.clone())) as ArrayRef)
+}
+
 fn cast_date_to_timestamp(
     array_ref: &ArrayRef,
     cast_options: &SparkCastOptions,
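
Note that Arrow timestamp values are always stored relative to the UTC epoch;
the timezone attached via with_timezone_opt is display metadata only. That is
why the test below can assert identical microsecond values for every timezone
in its list.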
@@ -3519,4 +3560,94 @@ mod tests {
         assert_eq!(r#"[null]"#, string_array.value(2));
         assert_eq!(r#"[]"#, string_array.value(3));
     }
+
+    #[test]
+    fn test_cast_int_to_timestamp() {
+        let timezones: [Option<Arc<str>>; 6] = [
+            Some(Arc::from("UTC")),
+            Some(Arc::from("America/New_York")),
+            Some(Arc::from("America/Los_Angeles")),
+            Some(Arc::from("Europe/London")),
+            Some(Arc::from("Asia/Tokyo")),
+            Some(Arc::from("Australia/Sydney")),
+        ];
+
+        for tz in &timezones {
+            let int8_array: ArrayRef = Arc::new(Int8Array::from(vec![
+                Some(0),
+                Some(1),
+                Some(-1),
+                Some(127),
+                Some(-128),
+                None,
+            ]));
+
+            let result = cast_int_to_timestamp(&int8_array, tz).unwrap();
+            let ts_array = result.as_primitive::<TimestampMicrosecondType>();
+
+            assert_eq!(ts_array.value(0), 0);
+            assert_eq!(ts_array.value(1), 1_000_000);
+            assert_eq!(ts_array.value(2), -1_000_000);
+            assert_eq!(ts_array.value(3), 127_000_000);
+            assert_eq!(ts_array.value(4), -128_000_000);
+            assert!(ts_array.is_null(5));
+            assert_eq!(ts_array.timezone(), tz.as_ref().map(|s| s.as_ref()));
+
+            let int16_array: ArrayRef = Arc::new(Int16Array::from(vec![
+                Some(0),
+                Some(1),
+                Some(-1),
+                Some(32767),
+                Some(-32768),
+                None,
+            ]));
+
+            let result = cast_int_to_timestamp(&int16_array, tz).unwrap();
+            let ts_array = result.as_primitive::<TimestampMicrosecondType>();
+
+            assert_eq!(ts_array.value(0), 0);
+            assert_eq!(ts_array.value(1), 1_000_000);
+            assert_eq!(ts_array.value(2), -1_000_000);
+            assert_eq!(ts_array.value(3), 32_767_000_000_i64);
+            assert_eq!(ts_array.value(4), -32_768_000_000_i64);
+            assert!(ts_array.is_null(5));
+            assert_eq!(ts_array.timezone(), tz.as_ref().map(|s| s.as_ref()));
+
+            let int32_array: ArrayRef = Arc::new(Int32Array::from(vec![
+                Some(0),
+                Some(1),
+                Some(-1),
+                Some(1704067200),
+                None,
+            ]));
+
+            let result = cast_int_to_timestamp(&int32_array, tz).unwrap();
+            let ts_array = result.as_primitive::<TimestampMicrosecondType>();
+
+            assert_eq!(ts_array.value(0), 0);
+            assert_eq!(ts_array.value(1), 1_000_000);
+            assert_eq!(ts_array.value(2), -1_000_000);
+            assert_eq!(ts_array.value(3), 1_704_067_200_000_000_i64);
+            assert!(ts_array.is_null(4));
+            assert_eq!(ts_array.timezone(), tz.as_ref().map(|s| s.as_ref()));
+
+            let int64_array: ArrayRef = Arc::new(Int64Array::from(vec![
+                Some(0),
+                Some(1),
+                Some(-1),
+                Some(i64::MAX),
+                Some(i64::MIN),
+            ]));
+
+            let result = cast_int_to_timestamp(&int64_array, tz).unwrap();
+            let ts_array = result.as_primitive::<TimestampMicrosecondType>();
+
+            assert_eq!(ts_array.value(0), 0);
+            assert_eq!(ts_array.value(1), 1_000_000_i64);
+            assert_eq!(ts_array.value(2), -1_000_000_i64);
+            assert_eq!(ts_array.value(3), i64::MAX);
+            assert_eq!(ts_array.value(4), i64::MIN);
+            assert_eq!(ts_array.timezone(), tz.as_ref().map(|s| s.as_ref()));
+        }
+    }
 }
diff --git a/spark/src/main/scala/org/apache/comet/expressions/CometCast.scala b/spark/src/main/scala/org/apache/comet/expressions/CometCast.scala
index 8cbe76a19..15dfcb2d7 100644
--- a/spark/src/main/scala/org/apache/comet/expressions/CometCast.scala
+++ b/spark/src/main/scala/org/apache/comet/expressions/CometCast.scala
@@ -299,6 +299,8 @@ object CometCast extends CometExpressionSerde[Cast] with CometExprShim {
         Compatible()
       case DataTypes.BinaryType if (evalMode == CometEvalMode.LEGACY) =>
         Compatible()
+      case DataTypes.TimestampType =>
+        Compatible()
       case _ =>
         unsupported(DataTypes.ByteType, toType)
     }
@@ -313,6 +315,8 @@ object CometCast extends CometExpressionSerde[Cast] with CometExprShim {
         Compatible()
       case DataTypes.BinaryType if (evalMode == CometEvalMode.LEGACY) =>
         Compatible()
+      case DataTypes.TimestampType =>
+        Compatible()
       case _ =>
         unsupported(DataTypes.ShortType, toType)
     }
@@ -328,6 +332,8 @@ object CometCast extends CometExpressionSerde[Cast] with CometExprShim {
       case _: DecimalType =>
         Compatible()
       case DataTypes.BinaryType if (evalMode == CometEvalMode.LEGACY) => Compatible()
+      case DataTypes.TimestampType =>
+        Compatible()
       case _ =>
         unsupported(DataTypes.IntegerType, toType)
     }
@@ -343,6 +349,8 @@ object CometCast extends CometExpressionSerde[Cast] with CometExprShim {
       case _: DecimalType =>
         Compatible()
       case DataTypes.BinaryType if (evalMode == CometEvalMode.LEGACY) => Compatible()
+      case DataTypes.TimestampType =>
+        Compatible()
       case _ =>
         unsupported(DataTypes.LongType, toType)
     }
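
With these four branches in place, casts from ByteType, ShortType, IntegerType,
and LongType to TimestampType are reported as Compatible(), so Comet executes
them natively rather than falling back to Spark.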
diff --git a/spark/src/test/scala/org/apache/comet/CometCastSuite.scala b/spark/src/test/scala/org/apache/comet/CometCastSuite.scala
index 326904d56..72c2390d7 100644
--- a/spark/src/test/scala/org/apache/comet/CometCastSuite.scala
+++ b/spark/src/test/scala/org/apache/comet/CometCastSuite.scala
@@ -65,6 +65,23 @@ class CometCastSuite extends CometTestBase with AdaptiveSparkPlanHelper {
   lazy val usingParquetExecWithIncompatTypes: Boolean =
     hasUnsignedSmallIntSafetyCheck(conf)
 
+  // Timezone list to check temporal type casts
+  private val compatibleTimezones = Seq(
+    "UTC",
+    "America/New_York",
+    "America/Chicago",
+    "America/Denver",
+    "America/Los_Angeles",
+    "Europe/London",
+    "Europe/Paris",
+    "Europe/Berlin",
+    "Asia/Tokyo",
+    "Asia/Shanghai",
+    "Asia/Singapore",
+    "Asia/Kolkata",
+    "Australia/Sydney",
+    "Pacific/Auckland")
+
   test("all valid cast combinations covered") {
     val names = testNames
 
@@ -223,12 +240,15 @@ class CometCastSuite extends CometTestBase with AdaptiveSparkPlanHelper {
       testTry = false)
   }
 
-  ignore("cast ByteType to TimestampType") {
-    // input: -1, expected: 1969-12-31 15:59:59.0, actual: 1969-12-31 15:59:59.999999
-    castTest(
-      generateBytes(),
-      DataTypes.TimestampType,
-      hasIncompatibleType = usingParquetExecWithIncompatTypes)
+  test("cast ByteType to TimestampType") {
+    compatibleTimezones.foreach { tz =>
+      withSQLConf(SQLConf.SESSION_LOCAL_TIMEZONE.key -> tz) {
+        castTest(
+          generateBytes(),
+          DataTypes.TimestampType,
+          hasIncompatibleType = usingParquetExecWithIncompatTypes)
+      }
+    }
   }
 
   // CAST from ShortType
@@ -300,12 +320,15 @@ class CometCastSuite extends CometTestBase with AdaptiveSparkPlanHelper {
       testTry = false)
   }
 
-  ignore("cast ShortType to TimestampType") {
-    // input: -1003, expected: 1969-12-31 15:43:17.0, actual: 1969-12-31 15:59:59.998997
-    castTest(
-      generateShorts(),
-      DataTypes.TimestampType,
-      hasIncompatibleType = usingParquetExecWithIncompatTypes)
+  test("cast ShortType to TimestampType") {
+    compatibleTimezones.foreach { tz =>
+      withSQLConf(SQLConf.SESSION_LOCAL_TIMEZONE.key -> tz) {
+        castTest(
+          generateShorts(),
+          DataTypes.TimestampType,
+          hasIncompatibleType = usingParquetExecWithIncompatTypes)
+      }
+    }
   }
 
   // CAST from integer
@@ -363,9 +386,12 @@ class CometCastSuite extends CometTestBase with AdaptiveSparkPlanHelper {
     castTest(generateInts(), DataTypes.BinaryType, testAnsi = false, testTry = false)
   }
 
-  ignore("cast IntegerType to TimestampType") {
-    // input: -1000479329, expected: 1938-04-19 01:04:31.0, actual: 1969-12-31 15:43:19.520671
-    castTest(generateInts(), DataTypes.TimestampType)
+  test("cast IntegerType to TimestampType") {
+    compatibleTimezones.foreach { tz =>
+      withSQLConf(SQLConf.SESSION_LOCAL_TIMEZONE.key -> tz) {
+        castTest(generateInts(), DataTypes.TimestampType)
+      }
+    }
   }
 
   // CAST from LongType
@@ -410,9 +436,17 @@ class CometCastSuite extends CometTestBase with AdaptiveSparkPlanHelper {
     castTest(generateLongs(), DataTypes.BinaryType, testAnsi = false, testTry = false)
   }
 
-  ignore("cast LongType to TimestampType") {
-    // java.lang.ArithmeticException: long overflow
-    castTest(generateLongs(), DataTypes.TimestampType)
+  test("cast LongType to TimestampType") {
+    // Casting back to long avoids java.sql.Timestamp overflow during collect() for extreme values
+    compatibleTimezones.foreach { tz =>
+      withSQLConf(SQLConf.SESSION_LOCAL_TIMEZONE.key -> tz) {
+        withTable("t1") {
+          generateLongs().write.saveAsTable("t1")
+          val df = spark.sql("select a, cast(cast(a as timestamp) as long) from t1")
+          checkSparkAnswerAndOperator(df)
+        }
+      }
+    }
   }
 
   // CAST from FloatType
@@ -1042,13 +1076,13 @@ class CometCastSuite extends CometTestBase with AdaptiveSparkPlanHelper {
 
   ignore("cast TimestampType to ShortType") {
     // https://github.com/apache/datafusion-comet/issues/352
-    // input: 2023-12-31 10:00:00.0, expected: -21472, actual: null]
+    // input: 2023-12-31 10:00:00.0, expected: -21472, actual: null
     castTest(generateTimestamps(), DataTypes.ShortType)
   }
 
   ignore("cast TimestampType to IntegerType") {
     // https://github.com/apache/datafusion-comet/issues/352
-    // input: 2023-12-31 10:00:00.0, expected: 1704045600, actual: null]
+    // input: 2023-12-31 10:00:00.0, expected: 1704045600, actual: null
     castTest(generateTimestamps(), DataTypes.IntegerType)
   }
 

