This is an automated email from the ASF dual-hosted git repository. akurmustafa pushed a commit to branch main in repository https://gitbox.apache.org/repos/asf/arrow-datafusion.git
The following commit(s) were added to refs/heads/main by this push: new bd705fe6bb Temporal datatype support for interval arithmetic (#5971) bd705fe6bb is described below commit bd705fe6bb6f018b71dc1f25409e45ec061fd6df Author: Berkay Şahin <124376117+berkaysynn...@users.noreply.github.com> AuthorDate: Fri Apr 14 00:16:26 2023 +0300 Temporal datatype support for interval arithmetic (#5971) * first implementation and tests of timestamp subtraction * improvement after review * postgre interval format option * random tests extended * corrections after review * operator check * flag is removed * clippy fix * toml conflict * minor changes * deterministic matches * simplifications (clippy error) * test format changed * minor test fix * Update scalar.rs * Refactoring and simplifications * Make ScalarValue support interval comparison * naming tests * macro renaming * renaming macro * ok till arrow kernel ops * macro will replace matches inside evaluate add tests macro will replace matches inside evaluate ready for review * Code refactor * retract changes in scalar and datetime * ts op interval with chrono functions * bug fix and refactor * test refactor * Enhance commenting * new binary operation logic, handling the inside errors * slt and minor changes * tz parsing excluded * replace try_binary and as_datetime, and keep timezone for ts+interval op * fix after merge * delete unused functions * ready to review * correction after merge * change match order * minor changes * simplifications * update lock file * Refactoring tests You can add a millisecond array as well, but I used Nano. 
* bug detected * bug fixed * update cargo * tests added * minor changes after merge * fix after merge * code simplification * Some simplifications * Update min_max.rs * arithmetics moved into macros * fix cargo.lock * remove unwraps from tests * Remove run-time string comparison from the interval min/max macro * adapt upstream changes of timezone signature --------- Co-authored-by: Mehmet Ozan Kabak <ozanka...@gmail.com> Co-authored-by: metesynnada <100111937+metesynn...@users.noreply.github.com> Co-authored-by: Mustafa Akur <mustafa.a...@synnada.ai> --- datafusion/common/src/scalar.rs | 68 ++-- .../src/physical_plan/joins/symmetric_hash_join.rs | 222 ++++++++++++- .../tests/sqllogictests/test_files/timestamps.slt | 1 + datafusion/physical-expr/src/aggregate/min_max.rs | 62 ++++ datafusion/physical-expr/src/expressions/binary.rs | 350 +++++++++++++++++++++ .../physical-expr/src/expressions/datetime.rs | 279 +++++++++++++--- .../physical-expr/src/intervals/cp_solver.rs | 18 +- .../src/intervals/interval_aritmetic.rs | 25 +- .../physical-expr/src/intervals/test_utils.rs | 55 +++- 9 files changed, 983 insertions(+), 97 deletions(-) diff --git a/datafusion/common/src/scalar.rs b/datafusion/common/src/scalar.rs index 63b6caa623..c22616883c 100644 --- a/datafusion/common/src/scalar.rs +++ b/datafusion/common/src/scalar.rs @@ -1192,9 +1192,9 @@ pub fn seconds_add_array<const INTERVAL_MODE: i8>( #[inline] pub fn milliseconds_add(ts_ms: i64, scalar: &ScalarValue, sign: i32) -> Result<i64> { - let secs = ts_ms / 1000; - let nsecs = ((ts_ms % 1000) * 1_000_000) as u32; - do_date_time_math(secs, nsecs, scalar, sign).map(|dt| dt.timestamp_millis()) + let secs = ts_ms.div_euclid(1000); + let nsecs = ts_ms.rem_euclid(1000) * 1_000_000; + do_date_time_math(secs, nsecs as u32, scalar, sign).map(|dt| dt.timestamp_millis()) } #[inline] @@ -1203,21 +1203,18 @@ pub fn milliseconds_add_array<const INTERVAL_MODE: i8>( interval: i128, sign: i32, ) -> Result<i64> { - let mut secs = ts_ms / 
1000; - let mut nsecs = ((ts_ms % 1000) * 1_000_000) as i32; - if nsecs < 0 { - secs -= 1; - nsecs += 1_000_000_000; - } + let secs = ts_ms.div_euclid(1000); + let nsecs = ts_ms.rem_euclid(1000) * 1_000_000; do_date_time_math_array::<INTERVAL_MODE>(secs, nsecs as u32, interval, sign) .map(|dt| dt.timestamp_millis()) } #[inline] pub fn microseconds_add(ts_us: i64, scalar: &ScalarValue, sign: i32) -> Result<i64> { - let secs = ts_us / 1_000_000; - let nsecs = ((ts_us % 1_000_000) * 1000) as u32; - do_date_time_math(secs, nsecs, scalar, sign).map(|dt| dt.timestamp_nanos() / 1000) + let secs = ts_us.div_euclid(1_000_000); + let nsecs = ts_us.rem_euclid(1_000_000) * 1_000; + do_date_time_math(secs, nsecs as u32, scalar, sign) + .map(|dt| dt.timestamp_nanos() / 1000) } #[inline] @@ -1226,21 +1223,17 @@ pub fn microseconds_add_array<const INTERVAL_MODE: i8>( interval: i128, sign: i32, ) -> Result<i64> { - let mut secs = ts_us / 1_000_000; - let mut nsecs = ((ts_us % 1_000_000) * 1000) as i32; - if nsecs < 0 { - secs -= 1; - nsecs += 1_000_000_000; - } + let secs = ts_us.div_euclid(1_000_000); + let nsecs = ts_us.rem_euclid(1_000_000) * 1_000; do_date_time_math_array::<INTERVAL_MODE>(secs, nsecs as u32, interval, sign) .map(|dt| dt.timestamp_nanos() / 1000) } #[inline] pub fn nanoseconds_add(ts_ns: i64, scalar: &ScalarValue, sign: i32) -> Result<i64> { - let secs = ts_ns / 1_000_000_000; - let nsecs = (ts_ns % 1_000_000_000) as u32; - do_date_time_math(secs, nsecs, scalar, sign).map(|dt| dt.timestamp_nanos()) + let secs = ts_ns.div_euclid(1_000_000_000); + let nsecs = ts_ns.rem_euclid(1_000_000_000); + do_date_time_math(secs, nsecs as u32, scalar, sign).map(|dt| dt.timestamp_nanos()) } #[inline] @@ -1249,12 +1242,8 @@ pub fn nanoseconds_add_array<const INTERVAL_MODE: i8>( interval: i128, sign: i32, ) -> Result<i64> { - let mut secs = ts_ns / 1_000_000_000; - let mut nsecs = (ts_ns % 1_000_000_000) as i32; - if nsecs < 0 { - secs -= 1; - nsecs += 1_000_000_000; - } + let 
secs = ts_ns.div_euclid(1_000_000_000); + let nsecs = ts_ns.rem_euclid(1_000_000_000); do_date_time_math_array::<INTERVAL_MODE>(secs, nsecs as u32, interval, sign) .map(|dt| dt.timestamp_nanos()) } @@ -1297,7 +1286,7 @@ fn do_date_time_math( ) -> Result<NaiveDateTime> { let prior = NaiveDateTime::from_timestamp_opt(secs, nsecs).ok_or_else(|| { DataFusionError::Internal(format!( - "Could not conert to NaiveDateTime: secs {secs} nsecs {nsecs} scalar {scalar:?} sign {sign}" + "Could not convert to NaiveDateTime: secs {secs} nsecs {nsecs} scalar {scalar:?} sign {sign}" )) })?; do_date_math(prior, scalar, sign) @@ -1312,7 +1301,7 @@ fn do_date_time_math_array<const INTERVAL_MODE: i8>( ) -> Result<NaiveDateTime> { let prior = NaiveDateTime::from_timestamp_opt(secs, nsecs).ok_or_else(|| { DataFusionError::Internal(format!( - "Could not conert to NaiveDateTime: secs {secs} nsecs {nsecs}" + "Could not convert to NaiveDateTime: secs {secs} nsecs {nsecs}" )) })?; do_date_math_array::<_, INTERVAL_MODE>(prior, interval, sign) @@ -1768,6 +1757,27 @@ impl ScalarValue { DataType::UInt64 => ScalarValue::UInt64(Some(0)), DataType::Float32 => ScalarValue::Float32(Some(0.0)), DataType::Float64 => ScalarValue::Float64(Some(0.0)), + DataType::Timestamp(TimeUnit::Second, tz) => { + ScalarValue::TimestampSecond(Some(0), tz.clone()) + } + DataType::Timestamp(TimeUnit::Millisecond, tz) => { + ScalarValue::TimestampMillisecond(Some(0), tz.clone()) + } + DataType::Timestamp(TimeUnit::Microsecond, tz) => { + ScalarValue::TimestampMicrosecond(Some(0), tz.clone()) + } + DataType::Timestamp(TimeUnit::Nanosecond, tz) => { + ScalarValue::TimestampNanosecond(Some(0), tz.clone()) + } + DataType::Interval(IntervalUnit::YearMonth) => { + ScalarValue::IntervalYearMonth(Some(0)) + } + DataType::Interval(IntervalUnit::DayTime) => { + ScalarValue::IntervalDayTime(Some(0)) + } + DataType::Interval(IntervalUnit::MonthDayNano) => { + ScalarValue::IntervalMonthDayNano(Some(0)) + } _ => { return 
Err(DataFusionError::NotImplemented(format!( "Can't create a zero scalar from data_type \"{datatype:?}\"" diff --git a/datafusion/core/src/physical_plan/joins/symmetric_hash_join.rs b/datafusion/core/src/physical_plan/joins/symmetric_hash_join.rs index 47c511a1a6..c249219033 100644 --- a/datafusion/core/src/physical_plan/joins/symmetric_hash_join.rs +++ b/datafusion/core/src/physical_plan/joins/symmetric_hash_join.rs @@ -1554,17 +1554,19 @@ impl SymmetricHashJoinStream { mod tests { use std::fs::File; - use arrow::array::ArrayRef; - use arrow::array::{Int32Array, TimestampNanosecondArray}; + use arrow::array::{ArrayRef, IntervalDayTimeArray}; + use arrow::array::{Int32Array, TimestampMillisecondArray}; use arrow::compute::SortOptions; - use arrow::datatypes::{DataType, Field, Schema}; + use arrow::datatypes::{DataType, Field, IntervalUnit, Schema, TimeUnit}; use arrow::util::pretty::pretty_format_batches; use rstest::*; use tempfile::TempDir; use datafusion_expr::Operator; use datafusion_physical_expr::expressions::{binary, col, Column}; - use datafusion_physical_expr::intervals::test_utils::gen_conjunctive_numeric_expr; + use datafusion_physical_expr::intervals::test_utils::{ + gen_conjunctive_numeric_expr, gen_conjunctive_temporal_expr, + }; use datafusion_physical_expr::PhysicalExpr; use crate::physical_plan::joins::{ @@ -1789,6 +1791,44 @@ mod tests { _ => unreachable!(), } } + fn join_expr_tests_fixture_temporal( + expr_id: usize, + left_col: Arc<dyn PhysicalExpr>, + right_col: Arc<dyn PhysicalExpr>, + schema: &Schema, + ) -> Result<Arc<dyn PhysicalExpr>> { + match expr_id { + // constructs ((left_col - INTERVAL '100ms') > (right_col - INTERVAL '200ms')) AND ((left_col - INTERVAL '450ms') < (right_col - INTERVAL '300ms')) + 0 => gen_conjunctive_temporal_expr( + left_col, + right_col, + Operator::Minus, + Operator::Minus, + Operator::Minus, + Operator::Minus, + ScalarValue::new_interval_dt(0, 100), // 100 ms + ScalarValue::new_interval_dt(0, 200), // 200 ms + 
ScalarValue::new_interval_dt(0, 450), // 450 ms + ScalarValue::new_interval_dt(0, 300), // 300 ms + schema, + ), + // constructs ((left_col - TIMESTAMP '2023-01-01:12.00.03') > (right_col - TIMESTAMP '2023-01-01:12.00.01')) AND ((left_col - TIMESTAMP '2023-01-01:12.00.00') < (right_col - TIMESTAMP '2023-01-01:12.00.02')) + 1 => gen_conjunctive_temporal_expr( + left_col, + right_col, + Operator::Minus, + Operator::Minus, + Operator::Minus, + Operator::Minus, + ScalarValue::TimestampMillisecond(Some(1672574403000), None), // 2023-01-01:12.00.03 + ScalarValue::TimestampMillisecond(Some(1672574401000), None), // 2023-01-01:12.00.01 + ScalarValue::TimestampMillisecond(Some(1672574400000), None), // 2023-01-01:12.00.00 + ScalarValue::TimestampMillisecond(Some(1672574402000), None), // 2023-01-01:12.00.02 + schema, + ), + _ => unreachable!(), + } + } fn build_sides_record_batches( table_size: i32, key_cardinality: (i32, i32), @@ -1833,9 +1873,15 @@ mod tests { .collect::<Vec<Option<i32>>>() })); - let time = Arc::new(TimestampNanosecondArray::from( + let time = Arc::new(TimestampMillisecondArray::from( + initial_range + .clone() + .map(|x| x as i64 + 1672531200000) // x + 2023-01-01:00.00.00 + .collect::<Vec<i64>>(), + )); + let interval_time: ArrayRef = Arc::new(IntervalDayTimeArray::from( initial_range - .map(|x| 1664264591000000000 + (5000000000 * (x as i64))) + .map(|x| x as i64 * 100) // x * 100ms .collect::<Vec<i64>>(), )); @@ -1849,6 +1895,7 @@ mod tests { ("l_asc_null_first", ordered_asc_null_first.clone()), ("l_asc_null_last", ordered_asc_null_last.clone()), ("l_desc_null_first", ordered_desc_null_first.clone()), + ("li1", interval_time.clone()), ])?; let right = RecordBatch::try_from_iter(vec![ ("ra1", ordered.clone()), @@ -1860,6 +1907,7 @@ mod tests { ("r_asc_null_first", ordered_asc_null_first), ("r_asc_null_last", ordered_asc_null_last), ("r_desc_null_first", ordered_desc_null_first), + ("ri1", interval_time), ])?; Ok((left, right)) } @@ -2781,4 +2829,166 @@ 
mod tests { assert_eq!(left_side_joiner.visited_rows.is_empty(), should_be_empty); Ok(()) } + #[rstest] + #[tokio::test(flavor = "multi_thread")] + async fn testing_with_temporal_columns( + #[values( + JoinType::Inner, + JoinType::Left, + JoinType::Right, + JoinType::RightSemi, + JoinType::LeftSemi, + JoinType::LeftAnti, + JoinType::RightAnti, + JoinType::Full + )] + join_type: JoinType, + #[values( + (4, 5), + (99, 12), + )] + cardinality: (i32, i32), + #[values(0, 1)] case_expr: usize, + ) -> Result<()> { + let config = SessionConfig::new().with_repartition_joins(false); + let session_ctx = SessionContext::with_config(config); + let task_ctx = session_ctx.task_ctx(); + let (left_batch, right_batch) = + build_sides_record_batches(TABLE_SIZE, cardinality)?; + let left_schema = &left_batch.schema(); + let right_schema = &right_batch.schema(); + let on = vec![( + Column::new_with_schema("lc1", left_schema)?, + Column::new_with_schema("rc1", right_schema)?, + )]; + let left_sorted = vec![PhysicalSortExpr { + expr: col("lt1", left_schema)?, + options: SortOptions { + descending: false, + nulls_first: true, + }, + }]; + let right_sorted = vec![PhysicalSortExpr { + expr: col("rt1", right_schema)?, + options: SortOptions { + descending: false, + nulls_first: true, + }, + }]; + let (left, right) = create_memory_table( + left_batch, + right_batch, + Some(left_sorted), + Some(right_sorted), + 13, + )?; + let intermediate_schema = Schema::new(vec![ + Field::new( + "left", + DataType::Timestamp(TimeUnit::Millisecond, None), + false, + ), + Field::new( + "right", + DataType::Timestamp(TimeUnit::Millisecond, None), + false, + ), + ]); + let filter_expr = join_expr_tests_fixture_temporal( + case_expr, + col("left", &intermediate_schema)?, + col("right", &intermediate_schema)?, + &intermediate_schema, + )?; + let column_indices = vec![ + ColumnIndex { + index: 3, + side: JoinSide::Left, + }, + ColumnIndex { + index: 3, + side: JoinSide::Right, + }, + ]; + let filter = 
JoinFilter::new(filter_expr, column_indices, intermediate_schema); + experiment(left, right, Some(filter), join_type, on, task_ctx).await?; + Ok(()) + } + #[rstest] + #[tokio::test(flavor = "multi_thread")] + async fn test_with_interval_columns( + #[values( + JoinType::Inner, + JoinType::Left, + JoinType::Right, + JoinType::RightSemi, + JoinType::LeftSemi, + JoinType::LeftAnti, + JoinType::RightAnti, + JoinType::Full + )] + join_type: JoinType, + #[values( + (4, 5), + (99, 12), + )] + cardinality: (i32, i32), + ) -> Result<()> { + let config = SessionConfig::new().with_repartition_joins(false); + let session_ctx = SessionContext::with_config(config); + let task_ctx = session_ctx.task_ctx(); + let (left_batch, right_batch) = + build_sides_record_batches(TABLE_SIZE, cardinality)?; + let left_schema = &left_batch.schema(); + let right_schema = &right_batch.schema(); + let on = vec![( + Column::new_with_schema("lc1", left_schema)?, + Column::new_with_schema("rc1", right_schema)?, + )]; + let left_sorted = vec![PhysicalSortExpr { + expr: col("li1", left_schema)?, + options: SortOptions { + descending: false, + nulls_first: true, + }, + }]; + let right_sorted = vec![PhysicalSortExpr { + expr: col("ri1", right_schema)?, + options: SortOptions { + descending: false, + nulls_first: true, + }, + }]; + let (left, right) = create_memory_table( + left_batch, + right_batch, + Some(left_sorted), + Some(right_sorted), + 13, + )?; + let intermediate_schema = Schema::new(vec![ + Field::new("left", DataType::Interval(IntervalUnit::DayTime), false), + Field::new("right", DataType::Interval(IntervalUnit::DayTime), false), + ]); + let filter_expr = join_expr_tests_fixture_temporal( + 0, + col("left", &intermediate_schema)?, + col("right", &intermediate_schema)?, + &intermediate_schema, + )?; + let column_indices = vec![ + ColumnIndex { + index: 9, + side: JoinSide::Left, + }, + ColumnIndex { + index: 9, + side: JoinSide::Right, + }, + ]; + let filter = JoinFilter::new(filter_expr, 
column_indices, intermediate_schema); + experiment(left, right, Some(filter), join_type, on, task_ctx).await?; + + Ok(()) + } } diff --git a/datafusion/core/tests/sqllogictests/test_files/timestamps.slt b/datafusion/core/tests/sqllogictests/test_files/timestamps.slt index f51a272a6b..07c42c377b 100644 --- a/datafusion/core/tests/sqllogictests/test_files/timestamps.slt +++ b/datafusion/core/tests/sqllogictests/test_files/timestamps.slt @@ -674,6 +674,7 @@ SELECT '2000-01-01T00:00:00'::timestamp - '2010-01-01T00:00:00'::timestamp; # Interval - Timestamp => error statement error DataFusion error: Error during planning: Interval\(MonthDayNano\) - Timestamp\(Nanosecond, None\) can't be evaluated because there isn't a common type to coerce the types to + SELECT i - ts1 from FOO; statement ok diff --git a/datafusion/physical-expr/src/aggregate/min_max.rs b/datafusion/physical-expr/src/aggregate/min_max.rs index 077804b1a6..e695ac400d 100644 --- a/datafusion/physical-expr/src/aggregate/min_max.rs +++ b/datafusion/physical-expr/src/aggregate/min_max.rs @@ -346,6 +346,29 @@ macro_rules! typed_min_max_string { }}; } +macro_rules! interval_choose_min_max { + (min) => { + std::cmp::Ordering::Greater + }; + (max) => { + std::cmp::Ordering::Less + }; +} + +macro_rules! interval_min_max { + ($OP:tt, $LHS:expr, $RHS:expr) => {{ + match $LHS.partial_cmp(&$RHS) { + Some(interval_choose_min_max!($OP)) => $RHS.clone(), + Some(_) => $LHS.clone(), + None => { + return Err(DataFusionError::Internal( + "Comparison error while computing interval min/max".to_string(), + )) + } + } + }}; +} + // min/max of two scalar values of the same type macro_rules! min_max { ($VALUE:expr, $DELTA:expr, $OP:ident) => {{ @@ -456,6 +479,45 @@ macro_rules! 
min_max { ) => { typed_min_max!(lhs, rhs, Time64Nanosecond, $OP) } + ( + ScalarValue::IntervalYearMonth(lhs), + ScalarValue::IntervalYearMonth(rhs), + ) => { + typed_min_max!(lhs, rhs, IntervalYearMonth, $OP) + } + ( + ScalarValue::IntervalMonthDayNano(lhs), + ScalarValue::IntervalMonthDayNano(rhs), + ) => { + typed_min_max!(lhs, rhs, IntervalMonthDayNano, $OP) + } + ( + ScalarValue::IntervalDayTime(lhs), + ScalarValue::IntervalDayTime(rhs), + ) => { + typed_min_max!(lhs, rhs, IntervalDayTime, $OP) + } + ( + ScalarValue::IntervalYearMonth(_), + ScalarValue::IntervalMonthDayNano(_), + ) | ( + ScalarValue::IntervalYearMonth(_), + ScalarValue::IntervalDayTime(_), + ) | ( + ScalarValue::IntervalMonthDayNano(_), + ScalarValue::IntervalDayTime(_), + ) | ( + ScalarValue::IntervalMonthDayNano(_), + ScalarValue::IntervalYearMonth(_), + ) | ( + ScalarValue::IntervalDayTime(_), + ScalarValue::IntervalYearMonth(_), + ) | ( + ScalarValue::IntervalDayTime(_), + ScalarValue::IntervalMonthDayNano(_), + ) => { + interval_min_max!($OP, $VALUE, $DELTA) + } e => { return Err(DataFusionError::Internal(format!( "MIN/MAX is not expected to receive scalars of incompatible types {:?}", diff --git a/datafusion/physical-expr/src/expressions/binary.rs b/datafusion/physical-expr/src/expressions/binary.rs index 001614f4da..420abe313d 100644 --- a/datafusion/physical-expr/src/expressions/binary.rs +++ b/datafusion/physical-expr/src/expressions/binary.rs @@ -50,10 +50,17 @@ use arrow::compute::kernels::comparison::{ eq_dyn_utf8_scalar, gt_dyn_utf8_scalar, gt_eq_dyn_utf8_scalar, lt_dyn_utf8_scalar, lt_eq_dyn_utf8_scalar, neq_dyn_utf8_scalar, }; +use arrow::compute::{try_unary, unary}; use arrow::datatypes::*; use adapter::{eq_dyn, gt_dyn, gt_eq_dyn, lt_dyn, lt_eq_dyn, neq_dyn}; use arrow::compute::kernels::concat_elements::concat_elements_utf8; +use datafusion_common::scalar::{ + calculate_naives, microseconds_add, microseconds_sub, milliseconds_add, + milliseconds_sub, nanoseconds_add, 
nanoseconds_sub, op_dt, op_dt_mdn, op_mdn, op_ym, + op_ym_dt, op_ym_mdn, parse_timezones, seconds_add, seconds_sub, MILLISECOND_MODE, + NANOSECOND_MODE, +}; use datafusion_expr::type_coercion::{is_timestamp, is_utf8_or_large_utf8}; use kernels::{ bitwise_and, bitwise_and_scalar, bitwise_or, bitwise_or_scalar, bitwise_shift_left, @@ -1240,6 +1247,349 @@ pub fn binary( Ok(Arc::new(BinaryExpr::new(lhs, op, rhs))) } +macro_rules! sub_timestamp_macro { + ($array:expr, $rhs:expr, $caster:expr, $interval_type:ty, $opt_tz_lhs:expr, $multiplier:expr, + $opt_tz_rhs:expr, $unit_sub:expr, $naive_sub_fn:expr, $counter:expr) => {{ + let prim_array = $caster(&$array)?; + let ret: PrimitiveArray<$interval_type> = try_unary(prim_array, |lhs| { + let (parsed_lhs_tz, parsed_rhs_tz) = + (parse_timezones($opt_tz_lhs)?, parse_timezones($opt_tz_rhs)?); + let (naive_lhs, naive_rhs) = calculate_naives::<$unit_sub>( + lhs.mul_wrapping($multiplier), + parsed_lhs_tz, + $rhs.mul_wrapping($multiplier), + parsed_rhs_tz, + )?; + Ok($naive_sub_fn($counter(&naive_lhs), $counter(&naive_rhs))) + })?; + Arc::new(ret) as ArrayRef + }}; +} +/// This function handles the Timestamp - Timestamp operations, +/// where the first one is an array, and the second one is a scalar, +/// hence the result is also an array. 
+pub fn ts_scalar_ts_op(array: ArrayRef, scalar: &ScalarValue) -> Result<ColumnarValue> { + let ret = match (array.data_type(), scalar) { + ( + DataType::Timestamp(TimeUnit::Second, opt_tz_lhs), + ScalarValue::TimestampSecond(Some(rhs), opt_tz_rhs), + ) => { + sub_timestamp_macro!( + array, + rhs, + as_timestamp_second_array, + IntervalDayTimeType, + opt_tz_lhs.as_deref(), + 1000, + opt_tz_rhs.as_deref(), + MILLISECOND_MODE, + seconds_sub, + NaiveDateTime::timestamp + ) + } + ( + DataType::Timestamp(TimeUnit::Millisecond, opt_tz_lhs), + ScalarValue::TimestampMillisecond(Some(rhs), opt_tz_rhs), + ) => { + sub_timestamp_macro!( + array, + rhs, + as_timestamp_millisecond_array, + IntervalDayTimeType, + opt_tz_lhs.as_deref(), + 1, + opt_tz_rhs.as_deref(), + MILLISECOND_MODE, + milliseconds_sub, + NaiveDateTime::timestamp_millis + ) + } + ( + DataType::Timestamp(TimeUnit::Microsecond, opt_tz_lhs), + ScalarValue::TimestampMicrosecond(Some(rhs), opt_tz_rhs), + ) => { + sub_timestamp_macro!( + array, + rhs, + as_timestamp_microsecond_array, + IntervalMonthDayNanoType, + opt_tz_lhs.as_deref(), + 1000, + opt_tz_rhs.as_deref(), + NANOSECOND_MODE, + microseconds_sub, + NaiveDateTime::timestamp_micros + ) + } + ( + DataType::Timestamp(TimeUnit::Nanosecond, opt_tz_lhs), + ScalarValue::TimestampNanosecond(Some(rhs), opt_tz_rhs), + ) => { + sub_timestamp_macro!( + array, + rhs, + as_timestamp_nanosecond_array, + IntervalMonthDayNanoType, + opt_tz_lhs.as_deref(), + 1, + opt_tz_rhs.as_deref(), + NANOSECOND_MODE, + nanoseconds_sub, + NaiveDateTime::timestamp_nanos + ) + } + (_, _) => { + return Err(DataFusionError::Internal(format!( + "Invalid array - scalar types for Timestamp subtraction: {:?} - {:?}", + array.data_type(), + scalar.get_datatype() + ))); + } + }; + Ok(ColumnarValue::Array(ret)) +} + +macro_rules! 
sub_timestamp_interval_macro { + ($array:expr, $as_timestamp:expr, $ts_type:ty, $fn_op:expr, $scalar:expr, $sign:expr, $tz:expr) => {{ + let array = $as_timestamp(&$array)?; + let ret: PrimitiveArray<$ts_type> = + try_unary::<$ts_type, _, $ts_type>(array, |ts_s| { + Ok($fn_op(ts_s, $scalar, $sign)?) + })?; + Arc::new(ret.with_timezone_opt($tz.clone())) as ArrayRef + }}; +} +/// This function handles the Timestamp - Interval operations, +/// where the first one is an array, and the second one is a scalar, +/// hence the result is also an array. +pub fn ts_scalar_interval_op( + array: ArrayRef, + sign: i32, + scalar: &ScalarValue, +) -> Result<ColumnarValue> { + let ret = match array.data_type() { + DataType::Timestamp(TimeUnit::Second, tz) => { + sub_timestamp_interval_macro!( + array, + as_timestamp_second_array, + TimestampSecondType, + seconds_add, + scalar, + sign, + tz + ) + } + DataType::Timestamp(TimeUnit::Millisecond, tz) => { + sub_timestamp_interval_macro!( + array, + as_timestamp_millisecond_array, + TimestampMillisecondType, + milliseconds_add, + scalar, + sign, + tz + ) + } + DataType::Timestamp(TimeUnit::Microsecond, tz) => { + sub_timestamp_interval_macro!( + array, + as_timestamp_microsecond_array, + TimestampMicrosecondType, + microseconds_add, + scalar, + sign, + tz + ) + } + DataType::Timestamp(TimeUnit::Nanosecond, tz) => { + sub_timestamp_interval_macro!( + array, + as_timestamp_nanosecond_array, + TimestampNanosecondType, + nanoseconds_add, + scalar, + sign, + tz + ) + } + _ => Err(DataFusionError::Internal(format!( + "Invalid lhs type for Timestamp vs Interval operations: {}", + array.data_type() + )))?, + }; + Ok(ColumnarValue::Array(ret)) +} + +macro_rules! 
sub_interval_macro { + ($array:expr, $as_interval:expr, $interval_type:ty, $fn_op:expr, $scalar:expr, $sign:expr) => {{ + let array = $as_interval(&$array)?; + let ret: PrimitiveArray<$interval_type> = + unary(array, |lhs| $fn_op(lhs, *$scalar, $sign)); + Arc::new(ret) as ArrayRef + }}; +} +macro_rules! sub_interval_cross_macro { + ($array:expr, $as_interval:expr, $commute:expr, $fn_op:expr, $scalar:expr, $sign:expr, $t1:ty, $t2:ty) => {{ + let array = $as_interval(&$array)?; + let ret: PrimitiveArray<IntervalMonthDayNanoType> = if $commute { + unary(array, |lhs| { + $fn_op(*$scalar as $t1, lhs as $t2, $sign, $commute) + }) + } else { + unary(array, |lhs| { + $fn_op(lhs as $t1, *$scalar as $t2, $sign, $commute) + }) + }; + Arc::new(ret) as ArrayRef + }}; +} +/// This function handles the Interval - Interval operations, +/// where the first one is an array, and the second one is a scalar, +/// hence the result is also an interval array. +pub fn interval_scalar_interval_op( + array: ArrayRef, + sign: i32, + scalar: &ScalarValue, +) -> Result<ColumnarValue> { + let ret = match (array.data_type(), scalar) { + ( + DataType::Interval(IntervalUnit::YearMonth), + ScalarValue::IntervalYearMonth(Some(rhs)), + ) => { + sub_interval_macro!( + array, + as_interval_ym_array, + IntervalYearMonthType, + op_ym, + rhs, + sign + ) + } + ( + DataType::Interval(IntervalUnit::YearMonth), + ScalarValue::IntervalDayTime(Some(rhs)), + ) => { + sub_interval_cross_macro!( + array, + as_interval_ym_array, + false, + op_ym_dt, + rhs, + sign, + i32, + i64 + ) + } + ( + DataType::Interval(IntervalUnit::YearMonth), + ScalarValue::IntervalMonthDayNano(Some(rhs)), + ) => { + sub_interval_cross_macro!( + array, + as_interval_ym_array, + false, + op_ym_mdn, + rhs, + sign, + i32, + i128 + ) + } + ( + DataType::Interval(IntervalUnit::DayTime), + ScalarValue::IntervalYearMonth(Some(rhs)), + ) => { + sub_interval_cross_macro!( + array, + as_interval_dt_array, + true, + op_ym_dt, + rhs, + sign, + i32, + 
i64 + ) + } + ( + DataType::Interval(IntervalUnit::DayTime), + ScalarValue::IntervalDayTime(Some(rhs)), + ) => { + sub_interval_macro!( + array, + as_interval_dt_array, + IntervalDayTimeType, + op_dt, + rhs, + sign + ) + } + ( + DataType::Interval(IntervalUnit::DayTime), + ScalarValue::IntervalMonthDayNano(Some(rhs)), + ) => { + sub_interval_cross_macro!( + array, + as_interval_dt_array, + false, + op_dt_mdn, + rhs, + sign, + i64, + i128 + ) + } + ( + DataType::Interval(IntervalUnit::MonthDayNano), + ScalarValue::IntervalYearMonth(Some(rhs)), + ) => { + sub_interval_cross_macro!( + array, + as_interval_mdn_array, + true, + op_ym_mdn, + rhs, + sign, + i32, + i128 + ) + } + ( + DataType::Interval(IntervalUnit::MonthDayNano), + ScalarValue::IntervalDayTime(Some(rhs)), + ) => { + sub_interval_cross_macro!( + array, + as_interval_mdn_array, + true, + op_dt_mdn, + rhs, + sign, + i64, + i128 + ) + } + ( + DataType::Interval(IntervalUnit::MonthDayNano), + ScalarValue::IntervalMonthDayNano(Some(rhs)), + ) => { + sub_interval_macro!( + array, + as_interval_mdn_array, + IntervalMonthDayNanoType, + op_mdn, + rhs, + sign + ) + } + _ => Err(DataFusionError::Internal(format!( + "Invalid operands for Interval vs Interval operations: {} - {}", + array.data_type(), + scalar.get_datatype(), + )))?, + }; + Ok(ColumnarValue::Array(ret)) +} + // Macros related with timestamp & interval operations macro_rules! ts_sub_op { ($lhs:ident, $rhs:ident, $lhs_tz:ident, $rhs_tz:ident, $coef:expr, $caster:expr, $op:expr, $ts_unit:expr, $mode:expr, $type_out:ty) => {{ diff --git a/datafusion/physical-expr/src/expressions/datetime.rs b/datafusion/physical-expr/src/expressions/datetime.rs index 9246074841..c2a54beceb 100644 --- a/datafusion/physical-expr/src/expressions/datetime.rs +++ b/datafusion/physical-expr/src/expressions/datetime.rs @@ -15,14 +15,13 @@ // specific language governing permissions and limitations // under the License. 
+use crate::intervals::cp_solver::{propagate_arithmetic, propagate_comparison}; +use crate::intervals::{apply_operator, Interval}; use crate::physical_expr::down_cast_any_ref; use crate::PhysicalExpr; use arrow::array::{Array, ArrayRef}; -use arrow::compute::unary; -use arrow::datatypes::{ - DataType, Date32Type, Date64Type, Schema, TimeUnit, TimestampMicrosecondType, - TimestampMillisecondType, TimestampNanosecondType, TimestampSecondType, -}; +use arrow::compute::try_unary; +use arrow::datatypes::{DataType, Date32Type, Date64Type, Schema}; use arrow::record_batch::RecordBatch; use datafusion_common::cast::*; @@ -35,7 +34,10 @@ use std::any::Any; use std::fmt::{Display, Formatter}; use std::sync::Arc; -use super::binary::{interval_array_op, ts_array_op, ts_interval_array_op}; +use super::binary::{ + interval_array_op, interval_scalar_interval_op, ts_array_op, ts_interval_array_op, + ts_scalar_interval_op, ts_scalar_ts_op, +}; /// Perform DATE/TIME/TIMESTAMP +/ INTERVAL math #[derive(Debug)] @@ -67,6 +69,7 @@ impl DateTimeIntervalExpr { DataType::Interval(_), ) | (DataType::Timestamp(_, _), Operator::Minus, DataType::Timestamp(_, _)) + | (DataType::Interval(_), Operator::Plus, DataType::Timestamp(_, _)) | ( DataType::Interval(_), Operator::Plus | Operator::Minus, @@ -78,7 +81,7 @@ impl DateTimeIntervalExpr { input_schema: input_schema.clone(), }), (lhs, _, rhs) => Err(DataFusionError::Execution(format!( - "Invalid operation between '{lhs}' and '{rhs}' for DateIntervalExpr" + "Invalid operation {op} between '{lhs}' and '{rhs}' for DateIntervalExpr" ))), } } @@ -149,7 +152,7 @@ impl PhysicalExpr for DateTimeIntervalExpr { })) } (ColumnarValue::Array(array_lhs), ColumnarValue::Scalar(operand_rhs)) => { - evaluate_array(array_lhs, sign, &operand_rhs) + evaluate_temporal_array(array_lhs, sign, &operand_rhs) } (ColumnarValue::Array(array_lhs), ColumnarValue::Array(array_rhs)) => { @@ -162,6 +165,42 @@ impl PhysicalExpr for DateTimeIntervalExpr { } } + fn 
evaluate_bounds(&self, children: &[&Interval]) -> Result<Interval> { + // Get children intervals: + let left_interval = children[0]; + let right_interval = children[1]; + // Calculate current node's interval: + apply_operator(&self.op, left_interval, right_interval) + } + + fn propagate_constraints( + &self, + interval: &Interval, + children: &[&Interval], + ) -> Result<Vec<Option<Interval>>> { + // Get children intervals. Graph brings + let left_interval = children[0]; + let right_interval = children[1]; + let (left, right) = if self.op.is_comparison_operator() { + if let Interval { + lower: ScalarValue::Boolean(Some(false)), + upper: ScalarValue::Boolean(Some(false)), + } = interval + { + // TODO: We will handle strictly false clauses by negating + // the comparison operator (e.g. GT to LE, LT to GE) + // once open/closed intervals are supported. + return Ok(vec![]); + } + // Propagate the comparison operator. + propagate_comparison(&self.op, left_interval, right_interval)? + } else { + // Propagate the arithmetic operator. + propagate_arithmetic(&self.op, interval, left_interval, right_interval)? + }; + Ok(vec![left, right]) + } + fn children(&self) -> Vec<Arc<dyn PhysicalExpr>> { vec![self.lhs.clone(), self.rhs.clone()] } @@ -188,64 +227,44 @@ impl PartialEq<dyn Any> for DateTimeIntervalExpr { } } -pub fn evaluate_array( +pub fn evaluate_temporal_array( array: ArrayRef, sign: i32, scalar: &ScalarValue, ) -> Result<ColumnarValue> { - let ret = match array.data_type() { - DataType::Date32 => { + match (array.data_type(), scalar.get_datatype()) { + // Date +- Interval + (DataType::Date32, DataType::Interval(_)) => { let array = as_date32_array(&array)?; - Arc::new(unary::<Date32Type, _, Date32Type>(array, |days| { - date32_add(days, scalar, sign).unwrap() - })) as ArrayRef + let ret = Arc::new(try_unary::<Date32Type, _, Date32Type>(array, |days| { + Ok(date32_add(days, scalar, sign)?) + })?) 
as ArrayRef; + Ok(ColumnarValue::Array(ret)) } - DataType::Date64 => { + (DataType::Date64, DataType::Interval(_)) => { let array = as_date64_array(&array)?; - Arc::new(unary::<Date64Type, _, Date64Type>(array, |ms| { - date64_add(ms, scalar, sign).unwrap() - })) as ArrayRef + let ret = Arc::new(try_unary::<Date64Type, _, Date64Type>(array, |ms| { + Ok(date64_add(ms, scalar, sign)?) + })?) as ArrayRef; + Ok(ColumnarValue::Array(ret)) } - DataType::Timestamp(TimeUnit::Second, _) => { - let array = as_timestamp_second_array(&array)?; - Arc::new(unary::<TimestampSecondType, _, TimestampSecondType>( - array, - |ts_s| seconds_add(ts_s, scalar, sign).unwrap(), - )) as ArrayRef + // Timestamp - Timestamp + (DataType::Timestamp(_, _), DataType::Timestamp(_, _)) if sign == -1 => { + ts_scalar_ts_op(array, scalar) } - DataType::Timestamp(TimeUnit::Millisecond, _) => { - let array = as_timestamp_millisecond_array(&array)?; - Arc::new( - unary::<TimestampMillisecondType, _, TimestampMillisecondType>( - array, - |ts_ms| milliseconds_add(ts_ms, scalar, sign).unwrap(), - ), - ) as ArrayRef - } - DataType::Timestamp(TimeUnit::Microsecond, _) => { - let array = as_timestamp_microsecond_array(&array)?; - Arc::new( - unary::<TimestampMicrosecondType, _, TimestampMicrosecondType>( - array, - |ts_us| microseconds_add(ts_us, scalar, sign).unwrap(), - ), - ) as ArrayRef + // Interval +- Interval + (DataType::Interval(_), DataType::Interval(_)) => { + interval_scalar_interval_op(array, sign, scalar) } - DataType::Timestamp(TimeUnit::Nanosecond, _) => { - let array = as_timestamp_nanosecond_array(&array)?; - Arc::new( - unary::<TimestampNanosecondType, _, TimestampNanosecondType>( - array, - |ts_ns| nanoseconds_add(ts_ns, scalar, sign).unwrap(), - ), - ) as ArrayRef + // Timestamp +- Interval + (DataType::Timestamp(_, _), DataType::Interval(_)) => { + ts_scalar_interval_op(array, sign, scalar) } - _ => Err(DataFusionError::Execution(format!( + (_, _) => 
Err(DataFusionError::Execution(format!( "Invalid lhs type for DateIntervalExpr: {}", array.data_type() )))?, - }; - Ok(ColumnarValue::Array(ret)) + } } // This function evaluates temporal array operations, such as timestamp - timestamp, interval + interval, @@ -291,6 +310,7 @@ mod tests { use crate::execution_props::ExecutionProps; use arrow::array::{ArrayRef, Date32Builder}; use arrow::datatypes::*; + use arrow_array::IntervalMonthDayNanoArray; use chrono::{Duration, NaiveDate}; use datafusion_common::delta::shift_months; use datafusion_common::{Column, Result, ToDFSchema}; @@ -685,4 +705,159 @@ mod tests { let res = cut.evaluate(&batch)?; Ok(res) } + + // In this test, ArrayRef of one element arrays is evaluated with some ScalarValues, + // aiming that evaluate_temporal_array function is working properly and shows the same + // behavior with ScalarValue arithmetic. + fn experiment( + timestamp_scalar: ScalarValue, + interval_scalar: ScalarValue, + ) -> Result<()> { + let timestamp_array = timestamp_scalar.to_array(); + let interval_array = interval_scalar.to_array(); + + // timestamp + interval + if let ColumnarValue::Array(res1) = + evaluate_temporal_array(timestamp_array.clone(), 1, &interval_scalar)? + { + let res2 = timestamp_scalar.add(&interval_scalar)?.to_array(); + assert_eq!( + &res1, &res2, + "Timestamp Scalar={} + Interval Scalar={}", + timestamp_scalar, interval_scalar + ); + } + + // timestamp - interval + if let ColumnarValue::Array(res1) = + evaluate_temporal_array(timestamp_array.clone(), -1, &interval_scalar)? + { + let res2 = timestamp_scalar.sub(&interval_scalar)?.to_array(); + assert_eq!( + &res1, &res2, + "Timestamp Scalar={} - Interval Scalar={}", + timestamp_scalar, interval_scalar + ); + } + + // timestamp - timestamp + if let ColumnarValue::Array(res1) = + evaluate_temporal_array(timestamp_array.clone(), -1, &timestamp_scalar)?
+ { + let res2 = timestamp_scalar.sub(&timestamp_scalar)?.to_array(); + assert_eq!( + &res1, &res2, + "Timestamp Scalar={} - Timestamp Scalar={}", + timestamp_scalar, timestamp_scalar + ); + } + + // interval - interval + if let ColumnarValue::Array(res1) = + evaluate_temporal_array(interval_array.clone(), -1, &interval_scalar)? + { + let res2 = interval_scalar.sub(&interval_scalar)?.to_array(); + assert_eq!( + &res1, &res2, + "Interval Scalar={} - Interval Scalar={}", + interval_scalar, interval_scalar + ); + } + + // interval + interval + if let ColumnarValue::Array(res1) = + evaluate_temporal_array(interval_array, 1, &interval_scalar)? + { + let res2 = interval_scalar.add(&interval_scalar)?.to_array(); + assert_eq!( + &res1, &res2, + "Interval Scalar={} + Interval Scalar={}", + interval_scalar, interval_scalar + ); + } + + Ok(()) + } + #[test] + fn test_evalute_with_scalar() -> Result<()> { + // Timestamp (sec) & Interval (DayTime) + let timestamp_scalar = ScalarValue::TimestampSecond( + Some( + NaiveDate::from_ymd_opt(2023, 1, 1) + .unwrap() + .and_hms_opt(0, 0, 0) + .unwrap() + .timestamp(), + ), + None, + ); + let interval_scalar = ScalarValue::new_interval_dt(0, 1_000); + + experiment(timestamp_scalar, interval_scalar)?; + + // Timestamp (millisec) & Interval (DayTime) + let timestamp_scalar = ScalarValue::TimestampMillisecond( + Some( + NaiveDate::from_ymd_opt(2023, 1, 1) + .unwrap() + .and_hms_milli_opt(0, 0, 0, 0) + .unwrap() + .timestamp_millis(), + ), + None, + ); + let interval_scalar = ScalarValue::new_interval_dt(0, 1_000); + + experiment(timestamp_scalar, interval_scalar)?; + + // Timestamp (nanosec) & Interval (MonthDayNano) + let timestamp_scalar = ScalarValue::TimestampNanosecond( + Some( + NaiveDate::from_ymd_opt(2023, 1, 1) + .unwrap() + .and_hms_nano_opt(0, 0, 0, 0) + .unwrap() + .timestamp_nanos(), + ), + None, + ); + let interval_scalar = ScalarValue::new_interval_mdn(0, 0, 1_000); + + experiment(timestamp_scalar, interval_scalar)?; + + //
Timestamp (nanosec) & Interval (MonthDayNano), negatively resulting cases + + let timestamp_scalar = ScalarValue::TimestampNanosecond( + Some( + NaiveDate::from_ymd_opt(1970, 1, 1) + .unwrap() + .and_hms_nano_opt(0, 0, 0, 000) + .unwrap() + .timestamp_nanos(), + ), + None, + ); + + Arc::new(IntervalMonthDayNanoArray::from(vec![1_000])); // 1 us + let interval_scalar = ScalarValue::new_interval_mdn(0, 0, 1_000); + + experiment(timestamp_scalar, interval_scalar)?; + + // Timestamp (sec) & Interval (YearMonth) + let timestamp_scalar = ScalarValue::TimestampSecond( + Some( + NaiveDate::from_ymd_opt(2023, 1, 1) + .unwrap() + .and_hms_opt(0, 0, 0) + .unwrap() + .timestamp(), + ), + None, + ); + let interval_scalar = ScalarValue::new_interval_ym(0, 1); + + experiment(timestamp_scalar, interval_scalar)?; + + Ok(()) + } } diff --git a/datafusion/physical-expr/src/intervals/cp_solver.rs b/datafusion/physical-expr/src/intervals/cp_solver.rs index 5e9353599e..5412d77cb2 100644 --- a/datafusion/physical-expr/src/intervals/cp_solver.rs +++ b/datafusion/physical-expr/src/intervals/cp_solver.rs @@ -23,6 +23,7 @@ use std::sync::Arc; use arrow_schema::DataType; use datafusion_common::{Result, ScalarValue}; +use datafusion_expr::type_coercion::binary::coerce_types; use datafusion_expr::Operator; use petgraph::graph::NodeIndex; use petgraph::stable_graph::{DefaultIx, StableGraph}; @@ -252,9 +253,14 @@ pub fn propagate_arithmetic( /// If we have expression < 0, expression must have the range [-∞, 0]. /// Currently, we only support strict inequalities since open/closed intervals /// are not implemented yet. 
-fn comparison_operator_target(datatype: &DataType, op: &Operator) -> Result<Interval> { - let unbounded = ScalarValue::try_from(datatype)?; - let zero = ScalarValue::new_zero(datatype)?; +fn comparison_operator_target( + left_datatype: &DataType, + op: &Operator, + right_datatype: &DataType, +) -> Result<Interval> { + let datatype = coerce_types(left_datatype, &Operator::Minus, right_datatype)?; + let unbounded = ScalarValue::try_from(&datatype)?; + let zero = ScalarValue::new_zero(&datatype)?; Ok(match *op { Operator::Gt => Interval { lower: zero, @@ -280,7 +286,11 @@ pub fn propagate_comparison( left_child: &Interval, right_child: &Interval, ) -> Result<(Option<Interval>, Option<Interval>)> { - let parent = comparison_operator_target(&left_child.get_datatype(), op)?; + let parent = comparison_operator_target( + &left_child.get_datatype(), + op, + &right_child.get_datatype(), + )?; propagate_arithmetic(&Operator::Minus, &parent, left_child, right_child) } diff --git a/datafusion/physical-expr/src/intervals/interval_aritmetic.rs b/datafusion/physical-expr/src/intervals/interval_aritmetic.rs index 7fc3641b25..94ac4e9a81 100644 --- a/datafusion/physical-expr/src/intervals/interval_aritmetic.rs +++ b/datafusion/physical-expr/src/intervals/interval_aritmetic.rs @@ -24,6 +24,7 @@ use std::fmt::{Display, Formatter}; use arrow::compute::{cast_with_options, CastOptions}; use arrow::datatypes::DataType; use datafusion_common::{DataFusionError, Result, ScalarValue}; +use datafusion_expr::type_coercion::binary::coerce_types; use datafusion_expr::Operator; use crate::aggregate::min_max::{max, min}; @@ -202,12 +203,20 @@ impl Interval { pub fn add<T: Borrow<Interval>>(&self, other: T) -> Result<Interval> { let rhs = other.borrow(); let lower = if self.lower.is_null() || rhs.lower.is_null() { - ScalarValue::try_from(self.lower.get_datatype()) + ScalarValue::try_from(&coerce_types( + &self.get_datatype(), + &Operator::Plus, + &rhs.get_datatype(), + )?) 
} else { self.lower.add(&rhs.lower) }?; let upper = if self.upper.is_null() || rhs.upper.is_null() { - ScalarValue::try_from(self.upper.get_datatype()) + ScalarValue::try_from(coerce_types( + &self.get_datatype(), + &Operator::Plus, + &rhs.get_datatype(), + )?) } else { self.upper.add(&rhs.upper) }?; @@ -221,12 +230,20 @@ impl Interval { pub fn sub<T: Borrow<Interval>>(&self, other: T) -> Result<Interval> { let rhs = other.borrow(); let lower = if self.lower.is_null() || rhs.upper.is_null() { - ScalarValue::try_from(self.lower.get_datatype()) + ScalarValue::try_from(coerce_types( + &self.get_datatype(), + &Operator::Minus, + &rhs.get_datatype(), + )?) } else { self.lower.sub(&rhs.upper) }?; let upper = if self.upper.is_null() || rhs.lower.is_null() { - ScalarValue::try_from(self.upper.get_datatype()) + ScalarValue::try_from(coerce_types( + &self.get_datatype(), + &Operator::Minus, + &rhs.get_datatype(), + )?) } else { self.upper.sub(&rhs.lower) }?; diff --git a/datafusion/physical-expr/src/intervals/test_utils.rs b/datafusion/physical-expr/src/intervals/test_utils.rs index ba02f4ff7a..3070cb0a8f 100644 --- a/datafusion/physical-expr/src/intervals/test_utils.rs +++ b/datafusion/physical-expr/src/intervals/test_utils.rs @@ -19,9 +19,10 @@ use std::sync::Arc; -use crate::expressions::{BinaryExpr, Literal}; +use crate::expressions::{BinaryExpr, DateTimeIntervalExpr, Literal}; use crate::PhysicalExpr; -use datafusion_common::ScalarValue; +use arrow_schema::Schema; +use datafusion_common::{DataFusionError, ScalarValue}; use datafusion_expr::Operator; #[allow(clippy::too_many_arguments)] @@ -65,3 +66,53 @@ pub fn gen_conjunctive_numeric_expr( let right_expr = Arc::new(BinaryExpr::new(right_and_1, Operator::Lt, right_and_2)); Arc::new(BinaryExpr::new(left_expr, Operator::And, right_expr)) } + +#[allow(clippy::too_many_arguments)] +/// This test function generates a conjunctive statement with +/// two scalar values with the following form: +/// left_col (op_1) a > right_col 
(op_2) b AND left_col (op_3) c < right_col (op_4) d +pub fn gen_conjunctive_temporal_expr( + left_col: Arc<dyn PhysicalExpr>, + right_col: Arc<dyn PhysicalExpr>, + op_1: Operator, + op_2: Operator, + op_3: Operator, + op_4: Operator, + a: ScalarValue, + b: ScalarValue, + c: ScalarValue, + d: ScalarValue, + schema: &Schema, +) -> Result<Arc<dyn PhysicalExpr>, DataFusionError> { + let left_and_1 = Arc::new(DateTimeIntervalExpr::try_new( + left_col.clone(), + op_1, + Arc::new(Literal::new(a)), + schema, + )?); + let left_and_2 = Arc::new(DateTimeIntervalExpr::try_new( + right_col.clone(), + op_2, + Arc::new(Literal::new(b)), + schema, + )?); + let right_and_1 = Arc::new(DateTimeIntervalExpr::try_new( + left_col, + op_3, + Arc::new(Literal::new(c)), + schema, + )?); + let right_and_2 = Arc::new(DateTimeIntervalExpr::try_new( + right_col, + op_4, + Arc::new(Literal::new(d)), + schema, + )?); + let left_expr = Arc::new(BinaryExpr::new(left_and_1, Operator::Gt, left_and_2)); + let right_expr = Arc::new(BinaryExpr::new(right_and_1, Operator::Lt, right_and_2)); + Ok(Arc::new(BinaryExpr::new( + left_expr, + Operator::And, + right_expr, + ))) +}