This is an automated email from the ASF dual-hosted git repository.
alamb pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/arrow-datafusion.git
The following commit(s) were added to refs/heads/main by this push:
new 1f8ede570 fix: cast literal to timestamp (#5517)
1f8ede570 is described below
commit 1f8ede5701c5013500fe01ec46903fe31de92785
Author: Alex Huang <[email protected]>
AuthorDate: Mon Mar 13 16:48:30 2023 +0100
fix: cast literal to timestamp (#5517)
* fix: cast literal to timestamp
* update tests for all transformations
* handle cast between same type
* refactor cast_between_timestamp to avoid overflow
* handle overflow to None
---
.../optimizer/src/unwrap_cast_in_comparison.rs | 214 ++++++++++++++++++++-
1 file changed, 210 insertions(+), 4 deletions(-)
diff --git a/datafusion/optimizer/src/unwrap_cast_in_comparison.rs b/datafusion/optimizer/src/unwrap_cast_in_comparison.rs
index 46c4d3522..a940cf272 100644
--- a/datafusion/optimizer/src/unwrap_cast_in_comparison.rs
+++ b/datafusion/optimizer/src/unwrap_cast_in_comparison.rs
@@ -24,6 +24,7 @@ use crate::{OptimizerConfig, OptimizerRule};
use arrow::datatypes::{
DataType, TimeUnit, MAX_DECIMAL_FOR_EACH_PRECISION,
MIN_DECIMAL_FOR_EACH_PRECISION,
};
+use arrow::temporal_conversions::{MICROSECONDS, MILLISECONDS, NANOSECONDS};
use datafusion_common::{DFSchemaRef, DataFusionError, Result, ScalarValue};
use datafusion_expr::expr::{BinaryExpr, Cast, TryCast};
use datafusion_expr::expr_rewriter::{ExprRewriter, RewriteRecursion};
@@ -31,6 +32,7 @@ use datafusion_expr::utils::from_plan;
use datafusion_expr::{
binary_expr, in_list, lit, Expr, ExprSchemable, LogicalPlan, Operator,
};
+use std::cmp::Ordering;
use std::sync::Arc;
/// [`UnwrapCastInComparison`] attempts to remove casts from
@@ -400,16 +402,36 @@ fn try_cast_literal_to_type(
DataType::UInt32 => ScalarValue::UInt32(Some(value as
u32)),
DataType::UInt64 => ScalarValue::UInt64(Some(value as
u64)),
DataType::Timestamp(TimeUnit::Second, tz) => {
- ScalarValue::TimestampSecond(Some(value as i64),
tz.clone())
+ let value = cast_between_timestamp(
+ lit_data_type,
+ DataType::Timestamp(TimeUnit::Second, tz.clone()),
+ value,
+ );
+ ScalarValue::TimestampSecond(value, tz.clone())
}
DataType::Timestamp(TimeUnit::Millisecond, tz) => {
- ScalarValue::TimestampMillisecond(Some(value as i64),
tz.clone())
+ let value = cast_between_timestamp(
+ lit_data_type,
+ DataType::Timestamp(TimeUnit::Millisecond,
tz.clone()),
+ value,
+ );
+ ScalarValue::TimestampMillisecond(value, tz.clone())
}
DataType::Timestamp(TimeUnit::Microsecond, tz) => {
- ScalarValue::TimestampMicrosecond(Some(value as i64),
tz.clone())
+ let value = cast_between_timestamp(
+ lit_data_type,
+ DataType::Timestamp(TimeUnit::Microsecond,
tz.clone()),
+ value,
+ );
+ ScalarValue::TimestampMicrosecond(value, tz.clone())
}
DataType::Timestamp(TimeUnit::Nanosecond, tz) => {
- ScalarValue::TimestampNanosecond(Some(value as i64),
tz.clone())
+ let value = cast_between_timestamp(
+ lit_data_type,
+ DataType::Timestamp(TimeUnit::Nanosecond,
tz.clone()),
+ value,
+ );
+ ScalarValue::TimestampNanosecond(value, tz.clone())
}
DataType::Decimal128(p, s) => {
ScalarValue::Decimal128(Some(value), *p, *s)
@@ -428,6 +450,32 @@ fn try_cast_literal_to_type(
}
}
+/// Cast a timestamp value from one unit to another
+fn cast_between_timestamp(from: DataType, to: DataType, value: i128) ->
Option<i64> {
+ let value = value as i64;
+ let from_scale = match from {
+ DataType::Timestamp(TimeUnit::Second, _) => 1,
+ DataType::Timestamp(TimeUnit::Millisecond, _) => MILLISECONDS,
+ DataType::Timestamp(TimeUnit::Microsecond, _) => MICROSECONDS,
+ DataType::Timestamp(TimeUnit::Nanosecond, _) => NANOSECONDS,
+ _ => return Some(value),
+ };
+
+ let to_scale = match to {
+ DataType::Timestamp(TimeUnit::Second, _) => 1,
+ DataType::Timestamp(TimeUnit::Millisecond, _) => MILLISECONDS,
+ DataType::Timestamp(TimeUnit::Microsecond, _) => MICROSECONDS,
+ DataType::Timestamp(TimeUnit::Nanosecond, _) => NANOSECONDS,
+ _ => return Some(value),
+ };
+
+ match from_scale.cmp(&to_scale) {
+ Ordering::Less => value.checked_mul(to_scale / from_scale),
+ Ordering::Greater => Some(value / (from_scale / to_scale)),
+ Ordering::Equal => Some(value),
+ }
+}
+
#[cfg(test)]
mod tests {
use super::*;
@@ -1070,4 +1118,162 @@ mod tests {
}
}
}
+
#[test]
fn test_try_cast_literal_to_timestamp() {
    // Shorthand constructors for timezone-less timestamp literals and types.
    let ns = |v: Option<i64>| ScalarValue::TimestampNanosecond(v, None);
    let us = |v: Option<i64>| ScalarValue::TimestampMicrosecond(v, None);
    let ms = |v: Option<i64>| ScalarValue::TimestampMillisecond(v, None);
    let s = |v: Option<i64>| ScalarValue::TimestampSecond(v, None);
    let ts = |unit: TimeUnit| DataType::Timestamp(unit, None);

    // Each case: (input literal, target type, expected literal after the cast).
    // Coarser target units truncate, finer target units multiply, and a
    // multiplication that overflows is expected to yield a NULL literal.
    let cases = [
        // same unit: value is unchanged
        (ns(Some(123456)), ts(TimeUnit::Nanosecond), ns(Some(123456))),
        // from nanoseconds
        (ns(Some(123456)), ts(TimeUnit::Microsecond), us(Some(123))),
        (ns(Some(123456)), ts(TimeUnit::Millisecond), ms(Some(0))),
        (ns(Some(123456)), ts(TimeUnit::Second), s(Some(0))),
        // from microseconds
        (us(Some(123)), ts(TimeUnit::Nanosecond), ns(Some(123000))),
        (us(Some(123)), ts(TimeUnit::Millisecond), ms(Some(0))),
        (us(Some(123456789)), ts(TimeUnit::Second), s(Some(123))),
        // from milliseconds
        (ms(Some(123)), ts(TimeUnit::Nanosecond), ns(Some(123000000))),
        (ms(Some(123)), ts(TimeUnit::Microsecond), us(Some(123000))),
        (ms(Some(123456789)), ts(TimeUnit::Second), s(Some(123456))),
        // from seconds
        (s(Some(123)), ts(TimeUnit::Nanosecond), ns(Some(123000000000))),
        (s(Some(123)), ts(TimeUnit::Microsecond), us(Some(123000000))),
        (s(Some(123)), ts(TimeUnit::Millisecond), ms(Some(123000))),
        // overflow while scaling up
        (s(Some(i64::MAX)), ts(TimeUnit::Millisecond), ms(None)),
    ];

    for (input, target_type, expected) in cases {
        let actual = try_cast_literal_to_type(&input, &target_type)
            .unwrap()
            .unwrap();
        assert_eq!(actual, expected, "casting {input:?} to {target_type:?}");
    }
}
}