This is an automated email from the ASF dual-hosted git repository.
alamb pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/arrow-datafusion.git
The following commit(s) were added to refs/heads/master by this push:
new bcd8af8e7 Add support for `DataType::Timestamp` casts in
`unwrap_cast_in_comparison` optimizer pass (#4148)
bcd8af8e7 is described below
commit bcd8af8e7cfdcafa948340a7de9a2101837eeaaf
Author: Andrew Lamb <[email protected]>
AuthorDate: Sat Nov 12 06:16:43 2022 -0500
Add support for `DataType::Timestamp` casts in `unwrap_cast_in_comparison`
optimizer pass (#4148)
* Add support for timestamp casts in unwrap_cast_in_comparison optimizer
pass
* correct comment in test
* Update datafusion/optimizer/src/unwrap_cast_in_comparison.rs
---
.../optimizer/src/unwrap_cast_in_comparison.rs | 156 ++++++++++++++++++++-
datafusion/optimizer/tests/integration-test.rs | 3 +-
2 files changed, 156 insertions(+), 3 deletions(-)
diff --git a/datafusion/optimizer/src/unwrap_cast_in_comparison.rs
b/datafusion/optimizer/src/unwrap_cast_in_comparison.rs
index 7ac91ae3c..28b085684 100644
--- a/datafusion/optimizer/src/unwrap_cast_in_comparison.rs
+++ b/datafusion/optimizer/src/unwrap_cast_in_comparison.rs
@@ -21,7 +21,7 @@
use crate::utils::rewrite_preserving_name;
use crate::{OptimizerConfig, OptimizerRule};
use arrow::datatypes::{
- DataType, MAX_DECIMAL_FOR_EACH_PRECISION, MIN_DECIMAL_FOR_EACH_PRECISION,
+ DataType, TimeUnit, MAX_DECIMAL_FOR_EACH_PRECISION,
MIN_DECIMAL_FOR_EACH_PRECISION,
};
use datafusion_common::{DFSchema, DFSchemaRef, DataFusionError, Result,
ScalarValue};
use datafusion_expr::expr::{BinaryExpr, Cast};
@@ -288,6 +288,7 @@ fn is_support_data_type(data_type: &DataType) -> bool {
| DataType::Int32
| DataType::Int64
| DataType::Decimal128(_, _)
+ | DataType::Timestamp(_, _)
)
}
@@ -306,6 +307,7 @@ fn try_cast_literal_to_type(
}
let mul = match target_type {
DataType::Int8 | DataType::Int16 | DataType::Int32 | DataType::Int64
=> 1_i128,
+ DataType::Timestamp(_, _) => 1_i128,
DataType::Decimal128(_, scale) => 10_i128.pow(*scale as u32),
other_type => {
return Err(DataFusionError::Internal(format!(
@@ -319,6 +321,7 @@ fn try_cast_literal_to_type(
DataType::Int16 => (i16::MIN as i128, i16::MAX as i128),
DataType::Int32 => (i32::MIN as i128, i32::MAX as i128),
DataType::Int64 => (i64::MIN as i128, i64::MAX as i128),
+ DataType::Timestamp(_, _) => (i64::MIN as i128, i64::MAX as i128),
DataType::Decimal128(precision, _) => (
// Different precision for decimal128 can store different range of
value.
// For example, the precision is 3, the max of value is `999` and
the min
@@ -338,6 +341,10 @@ fn try_cast_literal_to_type(
ScalarValue::Int16(Some(v)) => (*v as i128).checked_mul(mul),
ScalarValue::Int32(Some(v)) => (*v as i128).checked_mul(mul),
ScalarValue::Int64(Some(v)) => (*v as i128).checked_mul(mul),
+ ScalarValue::TimestampSecond(Some(v), _) => (*v as
i128).checked_mul(mul),
+ ScalarValue::TimestampMillisecond(Some(v), _) => (*v as
i128).checked_mul(mul),
+ ScalarValue::TimestampMicrosecond(Some(v), _) => (*v as
i128).checked_mul(mul),
+ ScalarValue::TimestampNanosecond(Some(v), _) => (*v as
i128).checked_mul(mul),
ScalarValue::Decimal128(Some(v), _, scale) => {
let lit_scale_mul = 10_i128.pow(*scale as u32);
if mul >= lit_scale_mul {
@@ -376,6 +383,18 @@ fn try_cast_literal_to_type(
DataType::Int16 => ScalarValue::Int16(Some(value as i16)),
DataType::Int32 => ScalarValue::Int32(Some(value as i32)),
DataType::Int64 => ScalarValue::Int64(Some(value as i64)),
+ DataType::Timestamp(TimeUnit::Second, tz) => {
+ ScalarValue::TimestampSecond(Some(value as i64),
tz.clone())
+ }
+ DataType::Timestamp(TimeUnit::Millisecond, tz) => {
+ ScalarValue::TimestampMillisecond(Some(value as i64),
tz.clone())
+ }
+ DataType::Timestamp(TimeUnit::Microsecond, tz) => {
+ ScalarValue::TimestampMicrosecond(Some(value as i64),
tz.clone())
+ }
+ DataType::Timestamp(TimeUnit::Nanosecond, tz) => {
+ ScalarValue::TimestampNanosecond(Some(value as i64),
tz.clone())
+ }
DataType::Decimal128(p, s) => {
ScalarValue::Decimal128(Some(value), *p, *s)
}
@@ -629,6 +648,18 @@ mod tests {
assert_eq!(optimize_test(expr_input.clone(), &schema), expr_input);
}
+ #[test]
+ /// Basic integration test for unwrapping casts with different timezones
+ fn test_unwrap_cast_with_timestamp_nanos() {
+ let schema = expr_test_schema();
+ // cast(ts_nano as Timestamp(Nanosecond, UTC)) <
1666612093000000000::Timestamp(Nanosecond, Utc))
+ let expr_lt = try_cast(col("ts_nano_none"), timestamp_nano_utc_type())
+ .lt(lit_timestamp_nano_utc(1666612093000000000));
+ let expected =
+
col("ts_nano_none").lt(lit_timestamp_nano_none(1666612093000000000));
+ assert_eq!(optimize_test(expr_lt, &schema), expected);
+ }
+
fn optimize_test(expr: Expr, schema: &DFSchemaRef) -> Expr {
let mut expr_rewriter = UnwrapCastExprRewriter {
schema: schema.clone(),
@@ -646,6 +677,8 @@ mod tests {
DFField::new(None, "c4", DataType::Decimal128(38, 37),
false),
DFField::new(None, "c5", DataType::Float32, false),
DFField::new(None, "c6", DataType::UInt32, false),
+ DFField::new(None, "ts_nano_none",
timestamp_nano_none_type(), false),
+ DFField::new(None, "ts_nano_utf",
timestamp_nano_utc_type(), false),
],
HashMap::new(),
)
@@ -669,13 +702,32 @@ mod tests {
lit(ScalarValue::Decimal128(Some(value), precision, scale))
}
+ fn lit_timestamp_nano_none(ts: i64) -> Expr {
+ lit(ScalarValue::TimestampNanosecond(Some(ts), None))
+ }
+
+ fn lit_timestamp_nano_utc(ts: i64) -> Expr {
+ let utc = Some("+0:00".to_string());
+ lit(ScalarValue::TimestampNanosecond(Some(ts), utc))
+ }
+
fn null_decimal(precision: u8, scale: u8) -> Expr {
lit(ScalarValue::Decimal128(None, precision, scale))
}
+ fn timestamp_nano_none_type() -> DataType {
+ DataType::Timestamp(TimeUnit::Nanosecond, None)
+ }
+
+ // this is the type that now() returns
+ fn timestamp_nano_utc_type() -> DataType {
+ let utc = Some("+0:00".to_string());
+ DataType::Timestamp(TimeUnit::Nanosecond, utc)
+ }
+
#[test]
fn test_try_cast_to_type_nulls() {
- // test values that can be cast to/from all integer types
+ // test that nulls can be cast to/from all integer types
let scalars = vec![
ScalarValue::Int8(None),
ScalarValue::Int16(None),
@@ -783,6 +835,106 @@ mod tests {
);
}
+ #[test]
+ fn test_try_cast_to_type_timestamps() {
+ for time_unit in [
+ TimeUnit::Second,
+ TimeUnit::Millisecond,
+ TimeUnit::Microsecond,
+ TimeUnit::Nanosecond,
+ ] {
+ let utc = Some("+0:00".to_string());
+ // No timezone, utc timezone
+ let (lit_tz_none, lit_tz_utc) = match time_unit {
+ TimeUnit::Second => (
+ ScalarValue::TimestampSecond(Some(12345), None),
+ ScalarValue::TimestampSecond(Some(12345), utc),
+ ),
+
+ TimeUnit::Millisecond => (
+ ScalarValue::TimestampMillisecond(Some(12345), None),
+ ScalarValue::TimestampMillisecond(Some(12345), utc),
+ ),
+
+ TimeUnit::Microsecond => (
+ ScalarValue::TimestampMicrosecond(Some(12345), None),
+ ScalarValue::TimestampMicrosecond(Some(12345), utc),
+ ),
+
+ TimeUnit::Nanosecond => (
+ ScalarValue::TimestampNanosecond(Some(12345), None),
+ ScalarValue::TimestampNanosecond(Some(12345), utc),
+ ),
+ };
+
+ // Datafusion ignores timezones for comparisons of ScalarValue
+ // so double check it here
+ assert_eq!(lit_tz_none, lit_tz_utc);
+
+ // e.g. DataType::Timestamp(_, None)
+ let dt_tz_none = lit_tz_none.get_datatype();
+
+ // e.g. DataType::Timestamp(_, Some(utc))
+ let dt_tz_utc = lit_tz_utc.get_datatype();
+
+ // None <--> None
+ expect_cast(
+ lit_tz_none.clone(),
+ dt_tz_none.clone(),
+ ExpectedCast::Value(lit_tz_none.clone()),
+ );
+
+ // None <--> Utc
+ expect_cast(
+ lit_tz_none.clone(),
+ dt_tz_utc.clone(),
+ ExpectedCast::Value(lit_tz_utc.clone()),
+ );
+
+ // Utc <--> None
+ expect_cast(
+ lit_tz_utc.clone(),
+ dt_tz_none.clone(),
+ ExpectedCast::Value(lit_tz_none.clone()),
+ );
+
+ // Utc <--> Utc
+ expect_cast(
+ lit_tz_utc.clone(),
+ dt_tz_utc.clone(),
+ ExpectedCast::Value(lit_tz_utc.clone()),
+ );
+
+ // timestamp to int64
+ expect_cast(
+ lit_tz_utc.clone(),
+ DataType::Int64,
+ ExpectedCast::Value(ScalarValue::Int64(Some(12345))),
+ );
+
+ // int64 to timestamp
+ expect_cast(
+ ScalarValue::Int64(Some(12345)),
+ dt_tz_none.clone(),
+ ExpectedCast::Value(lit_tz_none.clone()),
+ );
+
+ // int64 to timestamp
+ expect_cast(
+ ScalarValue::Int64(Some(12345)),
+ dt_tz_utc.clone(),
+ ExpectedCast::Value(lit_tz_utc.clone()),
+ );
+
+ // timestamp to string (not supported yet)
+ expect_cast(
+ lit_tz_utc.clone(),
+ DataType::LargeUtf8,
+ ExpectedCast::NoValue,
+ );
+ }
+ }
+
#[test]
fn test_try_cast_to_type_unsupported() {
// int64 to list
diff --git a/datafusion/optimizer/tests/integration-test.rs
b/datafusion/optimizer/tests/integration-test.rs
index be62ba2a5..48cd831bd 100644
--- a/datafusion/optimizer/tests/integration-test.rs
+++ b/datafusion/optimizer/tests/integration-test.rs
@@ -236,7 +236,8 @@ fn timestamp_nano_ts_none_predicates() -> Result<()> {
// constant and compared to the column without a cast so it can be
// pushed down / pruned
let expected =
- "Projection: test.col_int32\n Filter: CAST(test.col_ts_nano_none AS
Timestamp(Nanosecond, Some(\"+00:00\"))) <
TimestampNanosecond(1666612093000000000, Some(\"+00:00\"))\
+ "Projection: test.col_int32\
+ \n Filter: test.col_ts_nano_none <
TimestampNanosecond(1666612093000000000, None)\
\n TableScan: test projection=[col_int32, col_ts_nano_none]";
assert_eq!(expected, format!("{:?}", plan));
Ok(())