This is an automated email from the ASF dual-hosted git repository.

alamb pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/arrow-datafusion.git


The following commit(s) were added to refs/heads/main by this push:
     new 1f8ede570 fix: cast literal to timestamp (#5517)
1f8ede570 is described below

commit 1f8ede5701c5013500fe01ec46903fe31de92785
Author: Alex Huang <[email protected]>
AuthorDate: Mon Mar 13 16:48:30 2023 +0100

    fix: cast literal to timestamp (#5517)
    
    * fix: cast literal to timestamp
    
    * update tests for all transformation
    
    * handle cast between same type
    
    * refactor cast_between_timestamp to avoid overflow
    
    * handle overflow to None
---
 .../optimizer/src/unwrap_cast_in_comparison.rs     | 214 ++++++++++++++++++++-
 1 file changed, 210 insertions(+), 4 deletions(-)

diff --git a/datafusion/optimizer/src/unwrap_cast_in_comparison.rs 
b/datafusion/optimizer/src/unwrap_cast_in_comparison.rs
index 46c4d3522..a940cf272 100644
--- a/datafusion/optimizer/src/unwrap_cast_in_comparison.rs
+++ b/datafusion/optimizer/src/unwrap_cast_in_comparison.rs
@@ -24,6 +24,7 @@ use crate::{OptimizerConfig, OptimizerRule};
 use arrow::datatypes::{
     DataType, TimeUnit, MAX_DECIMAL_FOR_EACH_PRECISION, 
MIN_DECIMAL_FOR_EACH_PRECISION,
 };
+use arrow::temporal_conversions::{MICROSECONDS, MILLISECONDS, NANOSECONDS};
 use datafusion_common::{DFSchemaRef, DataFusionError, Result, ScalarValue};
 use datafusion_expr::expr::{BinaryExpr, Cast, TryCast};
 use datafusion_expr::expr_rewriter::{ExprRewriter, RewriteRecursion};
@@ -31,6 +32,7 @@ use datafusion_expr::utils::from_plan;
 use datafusion_expr::{
     binary_expr, in_list, lit, Expr, ExprSchemable, LogicalPlan, Operator,
 };
+use std::cmp::Ordering;
 use std::sync::Arc;
 
 /// [`UnwrapCastInComparison`] attempts to remove casts from
@@ -400,16 +402,36 @@ fn try_cast_literal_to_type(
                     DataType::UInt32 => ScalarValue::UInt32(Some(value as 
u32)),
                     DataType::UInt64 => ScalarValue::UInt64(Some(value as 
u64)),
                     DataType::Timestamp(TimeUnit::Second, tz) => {
-                        ScalarValue::TimestampSecond(Some(value as i64), 
tz.clone())
+                        let value = cast_between_timestamp(
+                            lit_data_type,
+                            DataType::Timestamp(TimeUnit::Second, tz.clone()),
+                            value,
+                        );
+                        ScalarValue::TimestampSecond(value, tz.clone())
                     }
                     DataType::Timestamp(TimeUnit::Millisecond, tz) => {
-                        ScalarValue::TimestampMillisecond(Some(value as i64), 
tz.clone())
+                        let value = cast_between_timestamp(
+                            lit_data_type,
+                            DataType::Timestamp(TimeUnit::Millisecond, 
tz.clone()),
+                            value,
+                        );
+                        ScalarValue::TimestampMillisecond(value, tz.clone())
                     }
                     DataType::Timestamp(TimeUnit::Microsecond, tz) => {
-                        ScalarValue::TimestampMicrosecond(Some(value as i64), 
tz.clone())
+                        let value = cast_between_timestamp(
+                            lit_data_type,
+                            DataType::Timestamp(TimeUnit::Microsecond, 
tz.clone()),
+                            value,
+                        );
+                        ScalarValue::TimestampMicrosecond(value, tz.clone())
                     }
                     DataType::Timestamp(TimeUnit::Nanosecond, tz) => {
-                        ScalarValue::TimestampNanosecond(Some(value as i64), 
tz.clone())
+                        let value = cast_between_timestamp(
+                            lit_data_type,
+                            DataType::Timestamp(TimeUnit::Nanosecond, 
tz.clone()),
+                            value,
+                        );
+                        ScalarValue::TimestampNanosecond(value, tz.clone())
                     }
                     DataType::Decimal128(p, s) => {
                         ScalarValue::Decimal128(Some(value), *p, *s)
@@ -428,6 +450,32 @@ fn try_cast_literal_to_type(
     }
 }
 
+/// Cast a timestamp value from one unit to another
+fn cast_between_timestamp(from: DataType, to: DataType, value: i128) -> 
Option<i64> {
+    let value = value as i64;
+    let from_scale = match from {
+        DataType::Timestamp(TimeUnit::Second, _) => 1,
+        DataType::Timestamp(TimeUnit::Millisecond, _) => MILLISECONDS,
+        DataType::Timestamp(TimeUnit::Microsecond, _) => MICROSECONDS,
+        DataType::Timestamp(TimeUnit::Nanosecond, _) => NANOSECONDS,
+        _ => return Some(value),
+    };
+
+    let to_scale = match to {
+        DataType::Timestamp(TimeUnit::Second, _) => 1,
+        DataType::Timestamp(TimeUnit::Millisecond, _) => MILLISECONDS,
+        DataType::Timestamp(TimeUnit::Microsecond, _) => MICROSECONDS,
+        DataType::Timestamp(TimeUnit::Nanosecond, _) => NANOSECONDS,
+        _ => return Some(value),
+    };
+
+    match from_scale.cmp(&to_scale) {
+        Ordering::Less => value.checked_mul(to_scale / from_scale),
+        Ordering::Greater => Some(value / (from_scale / to_scale)),
+        Ordering::Equal => Some(value),
+    }
+}
+
 #[cfg(test)]
 mod tests {
     use super::*;
@@ -1070,4 +1118,162 @@ mod tests {
             }
         }
     }
+
+    #[test]
+    fn test_try_cast_literal_to_timestamp() {
+        // same timestamp
+        let new_scalar = try_cast_literal_to_type(
+            &ScalarValue::TimestampNanosecond(Some(123456), None),
+            &DataType::Timestamp(TimeUnit::Nanosecond, None),
+        )
+        .unwrap()
+        .unwrap();
+
+        assert_eq!(
+            new_scalar,
+            ScalarValue::TimestampNanosecond(Some(123456), None)
+        );
+
+        // TimestampNanosecond to TimestampMicrosecond
+        let new_scalar = try_cast_literal_to_type(
+            &ScalarValue::TimestampNanosecond(Some(123456), None),
+            &DataType::Timestamp(TimeUnit::Microsecond, None),
+        )
+        .unwrap()
+        .unwrap();
+
+        assert_eq!(
+            new_scalar,
+            ScalarValue::TimestampMicrosecond(Some(123), None)
+        );
+
+        // TimestampNanosecond to TimestampMillisecond
+        let new_scalar = try_cast_literal_to_type(
+            &ScalarValue::TimestampNanosecond(Some(123456), None),
+            &DataType::Timestamp(TimeUnit::Millisecond, None),
+        )
+        .unwrap()
+        .unwrap();
+
+        assert_eq!(new_scalar, ScalarValue::TimestampMillisecond(Some(0), 
None));
+
+        // TimestampNanosecond to TimestampSecond
+        let new_scalar = try_cast_literal_to_type(
+            &ScalarValue::TimestampNanosecond(Some(123456), None),
+            &DataType::Timestamp(TimeUnit::Second, None),
+        )
+        .unwrap()
+        .unwrap();
+
+        assert_eq!(new_scalar, ScalarValue::TimestampSecond(Some(0), None));
+
+        // TimestampMicrosecond to TimestampNanosecond
+        let new_scalar = try_cast_literal_to_type(
+            &ScalarValue::TimestampMicrosecond(Some(123), None),
+            &DataType::Timestamp(TimeUnit::Nanosecond, None),
+        )
+        .unwrap()
+        .unwrap();
+
+        assert_eq!(
+            new_scalar,
+            ScalarValue::TimestampNanosecond(Some(123000), None)
+        );
+
+        // TimestampMicrosecond to TimestampMillisecond
+        let new_scalar = try_cast_literal_to_type(
+            &ScalarValue::TimestampMicrosecond(Some(123), None),
+            &DataType::Timestamp(TimeUnit::Millisecond, None),
+        )
+        .unwrap()
+        .unwrap();
+
+        assert_eq!(new_scalar, ScalarValue::TimestampMillisecond(Some(0), 
None));
+
+        // TimestampMicrosecond to TimestampSecond
+        let new_scalar = try_cast_literal_to_type(
+            &ScalarValue::TimestampMicrosecond(Some(123456789), None),
+            &DataType::Timestamp(TimeUnit::Second, None),
+        )
+        .unwrap()
+        .unwrap();
+        assert_eq!(new_scalar, ScalarValue::TimestampSecond(Some(123), None));
+
+        // TimestampMillisecond to TimestampNanosecond
+        let new_scalar = try_cast_literal_to_type(
+            &ScalarValue::TimestampMillisecond(Some(123), None),
+            &DataType::Timestamp(TimeUnit::Nanosecond, None),
+        )
+        .unwrap()
+        .unwrap();
+        assert_eq!(
+            new_scalar,
+            ScalarValue::TimestampNanosecond(Some(123000000), None)
+        );
+
+        // TimestampMillisecond to TimestampMicrosecond
+        let new_scalar = try_cast_literal_to_type(
+            &ScalarValue::TimestampMillisecond(Some(123), None),
+            &DataType::Timestamp(TimeUnit::Microsecond, None),
+        )
+        .unwrap()
+        .unwrap();
+        assert_eq!(
+            new_scalar,
+            ScalarValue::TimestampMicrosecond(Some(123000), None)
+        );
+        // TimestampMillisecond to TimestampSecond
+        let new_scalar = try_cast_literal_to_type(
+            &ScalarValue::TimestampMillisecond(Some(123456789), None),
+            &DataType::Timestamp(TimeUnit::Second, None),
+        )
+        .unwrap()
+        .unwrap();
+        assert_eq!(new_scalar, ScalarValue::TimestampSecond(Some(123456), 
None));
+
+        // TimestampSecond to TimestampNanosecond
+        let new_scalar = try_cast_literal_to_type(
+            &ScalarValue::TimestampSecond(Some(123), None),
+            &DataType::Timestamp(TimeUnit::Nanosecond, None),
+        )
+        .unwrap()
+        .unwrap();
+        assert_eq!(
+            new_scalar,
+            ScalarValue::TimestampNanosecond(Some(123000000000), None)
+        );
+
+        // TimestampSecond to TimestampMicrosecond
+        let new_scalar = try_cast_literal_to_type(
+            &ScalarValue::TimestampSecond(Some(123), None),
+            &DataType::Timestamp(TimeUnit::Microsecond, None),
+        )
+        .unwrap()
+        .unwrap();
+        assert_eq!(
+            new_scalar,
+            ScalarValue::TimestampMicrosecond(Some(123000000), None)
+        );
+
+        // TimestampSecond to TimestampMillisecond
+        let new_scalar = try_cast_literal_to_type(
+            &ScalarValue::TimestampSecond(Some(123), None),
+            &DataType::Timestamp(TimeUnit::Millisecond, None),
+        )
+        .unwrap()
+        .unwrap();
+        assert_eq!(
+            new_scalar,
+            ScalarValue::TimestampMillisecond(Some(123000), None)
+        );
+
+        // overflow
+        let new_scalar = try_cast_literal_to_type(
+            &ScalarValue::TimestampSecond(Some(i64::MAX), None),
+            &DataType::Timestamp(TimeUnit::Millisecond, None),
+        )
+        .unwrap()
+        .unwrap();
+        assert_eq!(new_scalar, ScalarValue::TimestampMillisecond(None, None));
+    }
 }

Reply via email to