This is an automated email from the ASF dual-hosted git repository.
agrove pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/datafusion-comet.git
The following commit(s) were added to refs/heads/main by this push:
new 219859bba feat: Support int to timestamp casts (#3541)
219859bba is described below
commit 219859bbae864c53d38fdb23b711284165fcc54b
Author: B Vadlamani <[email protected]>
AuthorDate: Thu Feb 19 07:41:01 2026 -0800
feat: Support int to timestamp casts (#3541)
---
native/spark-expr/Cargo.toml | 4 +
native/spark-expr/benches/cast_int_to_timestamp.rs | 131 +++++++++++++++++++++
native/spark-expr/src/conversion_funcs/cast.rs | 131 +++++++++++++++++++++
.../org/apache/comet/expressions/CometCast.scala | 8 ++
.../scala/org/apache/comet/CometCastSuite.scala | 74 ++++++++----
5 files changed, 328 insertions(+), 20 deletions(-)
diff --git a/native/spark-expr/Cargo.toml b/native/spark-expr/Cargo.toml
index fd0a211b2..63e1c0476 100644
--- a/native/spark-expr/Cargo.toml
+++ b/native/spark-expr/Cargo.toml
@@ -92,6 +92,10 @@ harness = false
name = "to_csv"
harness = false
+[[bench]]
+name = "cast_int_to_timestamp"
+harness = false
+
[[test]]
name = "test_udf_registration"
path = "tests/spark_expr_reg.rs"
diff --git a/native/spark-expr/benches/cast_int_to_timestamp.rs b/native/spark-expr/benches/cast_int_to_timestamp.rs
new file mode 100644
index 000000000..20143d2b0
--- /dev/null
+++ b/native/spark-expr/benches/cast_int_to_timestamp.rs
@@ -0,0 +1,131 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+use arrow::array::builder::{Int16Builder, Int32Builder, Int64Builder, Int8Builder};
+use arrow::array::RecordBatch;
+use arrow::datatypes::{DataType, Field, Schema, TimeUnit};
+use criterion::{criterion_group, criterion_main, Criterion};
+use datafusion::physical_expr::{expressions::Column, PhysicalExpr};
+use datafusion_comet_spark_expr::{Cast, EvalMode, SparkCastOptions};
+use std::sync::Arc;
+
+const BATCH_SIZE: usize = 8192;
+
+fn criterion_benchmark(c: &mut Criterion) {
+ // Test with UTC timezone
+ let spark_cast_options = SparkCastOptions::new(EvalMode::Legacy, "UTC", false);
+ let timestamp_type = DataType::Timestamp(TimeUnit::Microsecond, Some("UTC".into()));
+
+ let mut group = c.benchmark_group("cast_int_to_timestamp");
+
+ // Int8 -> Timestamp
+ let batch_i8 = create_int8_batch();
+ let expr_i8 = Arc::new(Column::new("a", 0));
+ let cast_i8_to_ts = Cast::new(expr_i8, timestamp_type.clone(), spark_cast_options.clone());
+ group.bench_function("cast_i8_to_timestamp", |b| {
+ b.iter(|| cast_i8_to_ts.evaluate(&batch_i8).unwrap());
+ });
+
+ // Int16 -> Timestamp
+ let batch_i16 = create_int16_batch();
+ let expr_i16 = Arc::new(Column::new("a", 0));
+ let cast_i16_to_ts = Cast::new(expr_i16, timestamp_type.clone(), spark_cast_options.clone());
+ group.bench_function("cast_i16_to_timestamp", |b| {
+ b.iter(|| cast_i16_to_ts.evaluate(&batch_i16).unwrap());
+ });
+
+ // Int32 -> Timestamp
+ let batch_i32 = create_int32_batch();
+ let expr_i32 = Arc::new(Column::new("a", 0));
+ let cast_i32_to_ts = Cast::new(expr_i32, timestamp_type.clone(), spark_cast_options.clone());
+ group.bench_function("cast_i32_to_timestamp", |b| {
+ b.iter(|| cast_i32_to_ts.evaluate(&batch_i32).unwrap());
+ });
+
+ // Int64 -> Timestamp
+ let batch_i64 = create_int64_batch();
+ let expr_i64 = Arc::new(Column::new("a", 0));
+ let cast_i64_to_ts = Cast::new(expr_i64, timestamp_type.clone(), spark_cast_options.clone());
+ group.bench_function("cast_i64_to_timestamp", |b| {
+ b.iter(|| cast_i64_to_ts.evaluate(&batch_i64).unwrap());
+ });
+
+ group.finish();
+}
+
+fn create_int8_batch() -> RecordBatch {
+ let schema = Arc::new(Schema::new(vec![Field::new("a", DataType::Int8,
true)]));
+ let mut b = Int8Builder::with_capacity(BATCH_SIZE);
+ for i in 0..BATCH_SIZE {
+ if i % 10 == 0 {
+ b.append_null();
+ } else {
+ b.append_value(rand::random::<i8>());
+ }
+ }
+ RecordBatch::try_new(schema, vec![Arc::new(b.finish())]).unwrap()
+}
+
+fn create_int16_batch() -> RecordBatch {
+ let schema = Arc::new(Schema::new(vec![Field::new("a", DataType::Int16,
true)]));
+ let mut b = Int16Builder::with_capacity(BATCH_SIZE);
+ for i in 0..BATCH_SIZE {
+ if i % 10 == 0 {
+ b.append_null();
+ } else {
+ b.append_value(rand::random::<i16>());
+ }
+ }
+ RecordBatch::try_new(schema, vec![Arc::new(b.finish())]).unwrap()
+}
+
+fn create_int32_batch() -> RecordBatch {
+ let schema = Arc::new(Schema::new(vec![Field::new("a", DataType::Int32,
true)]));
+ let mut b = Int32Builder::with_capacity(BATCH_SIZE);
+ for i in 0..BATCH_SIZE {
+ if i % 10 == 0 {
+ b.append_null();
+ } else {
+ b.append_value(rand::random::<i32>());
+ }
+ }
+ RecordBatch::try_new(schema, vec![Arc::new(b.finish())]).unwrap()
+}
+
+fn create_int64_batch() -> RecordBatch {
+ let schema = Arc::new(Schema::new(vec![Field::new("a", DataType::Int64,
true)]));
+ let mut b = Int64Builder::with_capacity(BATCH_SIZE);
+ for i in 0..BATCH_SIZE {
+ if i % 10 == 0 {
+ b.append_null();
+ } else {
+ b.append_value(rand::random::<i64>());
+ }
+ }
+ RecordBatch::try_new(schema, vec![Arc::new(b.finish())]).unwrap()
+}
+
+fn config() -> Criterion {
+ Criterion::default()
+}
+
+criterion_group! {
+ name = benches;
+ config = config();
+ targets = criterion_benchmark
+}
+criterion_main!(benches);
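Aside: the four create_*_batch helpers above are structurally identical. A generic version over any Arrow primitive type could collapse them; the sketch below is illustrative only (not part of the patch) and substitutes deterministic values via ArrowNativeType::usize_as so it does not depend on a particular rand version:

    use arrow::array::{ArrayRef, PrimitiveBuilder, RecordBatch};
    use arrow::datatypes::{ArrowNativeType, ArrowPrimitiveType, Field, Schema};
    use std::sync::Arc;

    const BATCH_SIZE: usize = 8192;

    // Hypothetical generic replacement for create_int8_batch .. create_int64_batch.
    fn create_batch<T: ArrowPrimitiveType>() -> RecordBatch {
        let schema = Arc::new(Schema::new(vec![Field::new("a", T::DATA_TYPE, true)]));
        let mut b = PrimitiveBuilder::<T>::with_capacity(BATCH_SIZE);
        for i in 0..BATCH_SIZE {
            if i % 10 == 0 {
                b.append_null(); // keep ~10% nulls, matching the per-type helpers
            } else {
                b.append_value(T::Native::usize_as(i)); // deterministic stand-in for random data
            }
        }
        RecordBatch::try_new(schema, vec![Arc::new(b.finish()) as ArrayRef]).unwrap()
    }

Each call site would then become create_batch::<Int8Type>() and so on, with the type parameters coming from arrow::datatypes.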
diff --git a/native/spark-expr/src/conversion_funcs/cast.rs b/native/spark-expr/src/conversion_funcs/cast.rs
index 2809104f2..f5ab83b8a 100644
--- a/native/spark-expr/src/conversion_funcs/cast.rs
+++ b/native/spark-expr/src/conversion_funcs/cast.rs
@@ -613,6 +613,23 @@ macro_rules! cast_decimal_to_int32_up {
}};
}
+macro_rules! cast_int_to_timestamp_impl {
+ ($array:expr, $builder:expr, $primitive_type:ty) => {{
+ let arr = $array.as_primitive::<$primitive_type>();
+ for i in 0..arr.len() {
+ if arr.is_null(i) {
+ $builder.append_null();
+ } else {
+ // saturating_mul clamps to i64::MIN/MAX on overflow instead of panicking,
+ // which could occur when converting extreme values (e.g., Long.MIN_VALUE),
+ // matching Spark's behavior (irrespective of EvalMode).
+ let micros = (arr.value(i) as i64).saturating_mul(MICROS_PER_SECOND);
+ $builder.append_value(micros);
+ }
+ }
+ }};
+}
+
// copied from arrow::dataTypes::Decimal128Type since Decimal128Type::format_decimal can't be called directly
fn format_decimal_str(value_str: &str, precision: usize, scale: i8) -> String {
let (sign, rest) = match value_str.strip_prefix('-') {
@@ -915,6 +932,7 @@ fn cast_array(
(Boolean, Decimal128(precision, scale)) => {
cast_boolean_to_decimal(&array, *precision, *scale)
}
+ (Int8 | Int16 | Int32 | Int64, Timestamp(_, tz)) => cast_int_to_timestamp(&array, tz),
_ if cast_options.is_adapting_schema
|| is_datafusion_spark_compatible(from_type, to_type) =>
{
@@ -933,6 +951,29 @@ fn cast_array(
Ok(spark_cast_postprocess(cast_result?, from_type, to_type))
}
+fn cast_int_to_timestamp(
+ array_ref: &ArrayRef,
+ target_tz: &Option<Arc<str>>,
+) -> SparkResult<ArrayRef> {
+ // Input is seconds since the epoch; multiply by MICROS_PER_SECOND to get microseconds.
+ let mut builder = TimestampMicrosecondBuilder::with_capacity(array_ref.len());
+
+ match array_ref.data_type() {
+ DataType::Int8 => cast_int_to_timestamp_impl!(array_ref, builder, Int8Type),
+ DataType::Int16 => cast_int_to_timestamp_impl!(array_ref, builder, Int16Type),
+ DataType::Int32 => cast_int_to_timestamp_impl!(array_ref, builder, Int32Type),
+ DataType::Int64 => cast_int_to_timestamp_impl!(array_ref, builder, Int64Type),
+ dt => {
+ return Err(SparkError::Internal(format!(
+ "Unsupported type for cast_int_to_timestamp: {:?}",
+ dt
+ )))
+ }
+ }
+
+ Ok(Arc::new(builder.finish().with_timezone_opt(target_tz.clone())) as ArrayRef)
+}
+
fn cast_date_to_timestamp(
array_ref: &ArrayRef,
cast_options: &SparkCastOptions,
@@ -3519,4 +3560,94 @@ mod tests {
assert_eq!(r#"[null]"#, string_array.value(2));
assert_eq!(r#"[]"#, string_array.value(3));
}
+
+ #[test]
+ fn test_cast_int_to_timestamp() {
+ let timezones: [Option<Arc<str>>; 6] = [
+ Some(Arc::from("UTC")),
+ Some(Arc::from("America/New_York")),
+ Some(Arc::from("America/Los_Angeles")),
+ Some(Arc::from("Europe/London")),
+ Some(Arc::from("Asia/Tokyo")),
+ Some(Arc::from("Australia/Sydney")),
+ ];
+
+ for tz in &timezones {
+ let int8_array: ArrayRef = Arc::new(Int8Array::from(vec![
+ Some(0),
+ Some(1),
+ Some(-1),
+ Some(127),
+ Some(-128),
+ None,
+ ]));
+
+ let result = cast_int_to_timestamp(&int8_array, tz).unwrap();
+ let ts_array = result.as_primitive::<TimestampMicrosecondType>();
+
+ assert_eq!(ts_array.value(0), 0);
+ assert_eq!(ts_array.value(1), 1_000_000);
+ assert_eq!(ts_array.value(2), -1_000_000);
+ assert_eq!(ts_array.value(3), 127_000_000);
+ assert_eq!(ts_array.value(4), -128_000_000);
+ assert!(ts_array.is_null(5));
+ assert_eq!(ts_array.timezone(), tz.as_ref().map(|s| s.as_ref()));
+
+ let int16_array: ArrayRef = Arc::new(Int16Array::from(vec![
+ Some(0),
+ Some(1),
+ Some(-1),
+ Some(32767),
+ Some(-32768),
+ None,
+ ]));
+
+ let result = cast_int_to_timestamp(&int16_array, tz).unwrap();
+ let ts_array = result.as_primitive::<TimestampMicrosecondType>();
+
+ assert_eq!(ts_array.value(0), 0);
+ assert_eq!(ts_array.value(1), 1_000_000);
+ assert_eq!(ts_array.value(2), -1_000_000);
+ assert_eq!(ts_array.value(3), 32_767_000_000_i64);
+ assert_eq!(ts_array.value(4), -32_768_000_000_i64);
+ assert!(ts_array.is_null(5));
+ assert_eq!(ts_array.timezone(), tz.as_ref().map(|s| s.as_ref()));
+
+ let int32_array: ArrayRef = Arc::new(Int32Array::from(vec![
+ Some(0),
+ Some(1),
+ Some(-1),
+ Some(1704067200),
+ None,
+ ]));
+
+ let result = cast_int_to_timestamp(&int32_array, tz).unwrap();
+ let ts_array = result.as_primitive::<TimestampMicrosecondType>();
+
+ assert_eq!(ts_array.value(0), 0);
+ assert_eq!(ts_array.value(1), 1_000_000);
+ assert_eq!(ts_array.value(2), -1_000_000);
+ assert_eq!(ts_array.value(3), 1_704_067_200_000_000_i64);
+ assert!(ts_array.is_null(4));
+ assert_eq!(ts_array.timezone(), tz.as_ref().map(|s| s.as_ref()));
+
+ let int64_array: ArrayRef = Arc::new(Int64Array::from(vec![
+ Some(0),
+ Some(1),
+ Some(-1),
+ Some(i64::MAX),
+ Some(i64::MIN),
+ ]));
+
+ let result = cast_int_to_timestamp(&int64_array, tz).unwrap();
+ let ts_array = result.as_primitive::<TimestampMicrosecondType>();
+
+ assert_eq!(ts_array.value(0), 0);
+ assert_eq!(ts_array.value(1), 1_000_000_i64);
+ assert_eq!(ts_array.value(2), -1_000_000_i64);
+ assert_eq!(ts_array.value(3), i64::MAX);
+ assert_eq!(ts_array.value(4), i64::MIN);
+ assert_eq!(ts_array.timezone(), tz.as_ref().map(|s| s.as_ref()));
+ }
+ }
}
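The saturating_mul in the macro above is the crux of the overflow handling: the seconds-to-microseconds multiply clamps to i64::MIN/i64::MAX rather than panicking on extreme inputs, which is exactly what the i64::MAX and i64::MIN assertions in the unit test exercise. A standalone sketch of that behavior (illustrative, not part of the patch):

    const MICROS_PER_SECOND: i64 = 1_000_000;

    fn seconds_to_micros(seconds: i64) -> i64 {
        // Clamps to i64::MIN/MAX on overflow instead of panicking.
        seconds.saturating_mul(MICROS_PER_SECOND)
    }

    fn main() {
        assert_eq!(seconds_to_micros(1), 1_000_000);
        assert_eq!(seconds_to_micros(-1), -1_000_000);
        // Extreme values (e.g., Long.MIN_VALUE arriving from Spark) clamp
        // instead of overflowing the multiply.
        assert_eq!(seconds_to_micros(i64::MAX), i64::MAX);
        assert_eq!(seconds_to_micros(i64::MIN), i64::MIN);
    }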
diff --git a/spark/src/main/scala/org/apache/comet/expressions/CometCast.scala b/spark/src/main/scala/org/apache/comet/expressions/CometCast.scala
index 8cbe76a19..15dfcb2d7 100644
--- a/spark/src/main/scala/org/apache/comet/expressions/CometCast.scala
+++ b/spark/src/main/scala/org/apache/comet/expressions/CometCast.scala
@@ -299,6 +299,8 @@ object CometCast extends CometExpressionSerde[Cast] with CometExprShim {
Compatible()
case DataTypes.BinaryType if (evalMode == CometEvalMode.LEGACY) =>
Compatible()
+ case DataTypes.TimestampType =>
+ Compatible()
case _ =>
unsupported(DataTypes.ByteType, toType)
}
@@ -313,6 +315,8 @@ object CometCast extends CometExpressionSerde[Cast] with CometExprShim {
Compatible()
case DataTypes.BinaryType if (evalMode == CometEvalMode.LEGACY) =>
Compatible()
+ case DataTypes.TimestampType =>
+ Compatible()
case _ =>
unsupported(DataTypes.ShortType, toType)
}
@@ -328,6 +332,8 @@ object CometCast extends CometExpressionSerde[Cast] with CometExprShim {
case _: DecimalType =>
Compatible()
case DataTypes.BinaryType if (evalMode == CometEvalMode.LEGACY) =>
Compatible()
+ case DataTypes.TimestampType =>
+ Compatible()
case _ =>
unsupported(DataTypes.IntegerType, toType)
}
@@ -343,6 +349,8 @@ object CometCast extends CometExpressionSerde[Cast] with CometExprShim {
case _: DecimalType =>
Compatible()
case DataTypes.BinaryType if (evalMode == CometEvalMode.LEGACY) =>
Compatible()
+ case DataTypes.TimestampType =>
+ Compatible()
case _ =>
unsupported(DataTypes.LongType, toType)
}
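With CometCast now reporting these casts as Compatible, the native path is reachable end to end. A minimal sketch of driving the new cast directly through the public Cast expression, mirroring the benchmark earlier in this commit (it assumes only the APIs that benchmark already uses):

    use std::sync::Arc;

    use arrow::array::{Int64Array, RecordBatch};
    use arrow::datatypes::{DataType, Field, Schema, TimeUnit};
    use datafusion::physical_expr::{expressions::Column, PhysicalExpr};
    use datafusion_comet_spark_expr::{Cast, EvalMode, SparkCastOptions};

    fn main() -> Result<(), Box<dyn std::error::Error>> {
        let schema = Arc::new(Schema::new(vec![Field::new("a", DataType::Int64, true)]));
        let batch = RecordBatch::try_new(
            schema,
            vec![Arc::new(Int64Array::from(vec![Some(0), Some(1), None]))],
        )?;
        let options = SparkCastOptions::new(EvalMode::Legacy, "UTC", false);
        let ts_type = DataType::Timestamp(TimeUnit::Microsecond, Some("UTC".into()));
        let cast = Cast::new(Arc::new(Column::new("a", 0)), ts_type, options);
        // Inputs are interpreted as seconds since the epoch, so 1 -> 1_000_000 micros.
        let result = cast.evaluate(&batch)?;
        println!("{result:?}");
        Ok(())
    }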
diff --git a/spark/src/test/scala/org/apache/comet/CometCastSuite.scala b/spark/src/test/scala/org/apache/comet/CometCastSuite.scala
index 326904d56..72c2390d7 100644
--- a/spark/src/test/scala/org/apache/comet/CometCastSuite.scala
+++ b/spark/src/test/scala/org/apache/comet/CometCastSuite.scala
@@ -65,6 +65,23 @@ class CometCastSuite extends CometTestBase with AdaptiveSparkPlanHelper {
lazy val usingParquetExecWithIncompatTypes: Boolean =
hasUnsignedSmallIntSafetyCheck(conf)
+ // Timezone list to check temporal type casts
+ private val compatibleTimezones = Seq(
+ "UTC",
+ "America/New_York",
+ "America/Chicago",
+ "America/Denver",
+ "America/Los_Angeles",
+ "Europe/London",
+ "Europe/Paris",
+ "Europe/Berlin",
+ "Asia/Tokyo",
+ "Asia/Shanghai",
+ "Asia/Singapore",
+ "Asia/Kolkata",
+ "Australia/Sydney",
+ "Pacific/Auckland")
+
test("all valid cast combinations covered") {
val names = testNames
@@ -223,12 +240,15 @@ class CometCastSuite extends CometTestBase with AdaptiveSparkPlanHelper {
testTry = false)
}
- ignore("cast ByteType to TimestampType") {
- // input: -1, expected: 1969-12-31 15:59:59.0, actual: 1969-12-31 15:59:59.999999
- castTest(
- generateBytes(),
- DataTypes.TimestampType,
- hasIncompatibleType = usingParquetExecWithIncompatTypes)
+ test("cast ByteType to TimestampType") {
+ compatibleTimezones.foreach { tz =>
+ withSQLConf(SQLConf.SESSION_LOCAL_TIMEZONE.key -> tz) {
+ castTest(
+ generateBytes(),
+ DataTypes.TimestampType,
+ hasIncompatibleType = usingParquetExecWithIncompatTypes)
+ }
+ }
}
// CAST from ShortType
@@ -300,12 +320,15 @@ class CometCastSuite extends CometTestBase with AdaptiveSparkPlanHelper {
testTry = false)
}
- ignore("cast ShortType to TimestampType") {
- // input: -1003, expected: 1969-12-31 15:43:17.0, actual: 1969-12-31 15:59:59.998997
- castTest(
- generateShorts(),
- DataTypes.TimestampType,
- hasIncompatibleType = usingParquetExecWithIncompatTypes)
+ test("cast ShortType to TimestampType") {
+ compatibleTimezones.foreach { tz =>
+ withSQLConf(SQLConf.SESSION_LOCAL_TIMEZONE.key -> tz) {
+ castTest(
+ generateShorts(),
+ DataTypes.TimestampType,
+ hasIncompatibleType = usingParquetExecWithIncompatTypes)
+ }
+ }
}
// CAST from integer
@@ -363,9 +386,12 @@ class CometCastSuite extends CometTestBase with AdaptiveSparkPlanHelper {
castTest(generateInts(), DataTypes.BinaryType, testAnsi = false, testTry = false)
}
- ignore("cast IntegerType to TimestampType") {
- // input: -1000479329, expected: 1938-04-19 01:04:31.0, actual: 1969-12-31 15:43:19.520671
- castTest(generateInts(), DataTypes.TimestampType)
+ test("cast IntegerType to TimestampType") {
+ compatibleTimezones.foreach { tz =>
+ withSQLConf(SQLConf.SESSION_LOCAL_TIMEZONE.key -> tz) {
+ castTest(generateInts(), DataTypes.TimestampType)
+ }
+ }
}
// CAST from LongType
@@ -410,9 +436,17 @@ class CometCastSuite extends CometTestBase with AdaptiveSparkPlanHelper {
castTest(generateLongs(), DataTypes.BinaryType, testAnsi = false, testTry = false)
}
- ignore("cast LongType to TimestampType") {
- // java.lang.ArithmeticException: long overflow
- castTest(generateLongs(), DataTypes.TimestampType)
+ test("cast LongType to TimestampType") {
+ // Casting back to long avoids java.sql.Timestamp overflow during collect() for extreme values.
+ compatibleTimezones.foreach { tz =>
+ withSQLConf(SQLConf.SESSION_LOCAL_TIMEZONE.key -> tz) {
+ withTable("t1") {
+ generateLongs().write.saveAsTable("t1")
+ val df = spark.sql("select a, cast(cast(a as timestamp) as long)
from t1")
+ checkSparkAnswerAndOperator(df)
+ }
+ }
+ }
}
// CAST from FloatType
@@ -1042,13 +1076,13 @@ class CometCastSuite extends CometTestBase with AdaptiveSparkPlanHelper {
ignore("cast TimestampType to ShortType") {
// https://github.com/apache/datafusion-comet/issues/352
- // input: 2023-12-31 10:00:00.0, expected: -21472, actual: null]
+ // input: 2023-12-31 10:00:00.0, expected: -21472, actual: null
castTest(generateTimestamps(), DataTypes.ShortType)
}
ignore("cast TimestampType to IntegerType") {
// https://github.com/apache/datafusion-comet/issues/352
- // input: 2023-12-31 10:00:00.0, expected: 1704045600, actual: null]
+ // input: 2023-12-31 10:00:00.0, expected: 1704045600, actual: null
castTest(generateTimestamps(), DataTypes.IntegerType)
}
---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]