This is an automated email from the ASF dual-hosted git repository.
alamb pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/arrow-datafusion.git
The following commit(s) were added to refs/heads/main by this push:
new 92bea9931a Float support on interval arithmetic (#6048)
92bea9931a is described below
commit 92bea9931acdc697530724a6535a97b8f62e15b0
Author: Metehan Yıldırım <[email protected]>
AuthorDate: Thu Apr 20 15:46:11 2023 +0300
Float support on interval arithmetic (#6048)
* Initial impl
* Test float support
* Windows platform support
* Correct the platform
* Update Cargo.lock
* Resolving merge problems
* Code refactor
* Merge resolution
* Merge conflicts
* Separate rounding mode code into a module
* Refactor IA tests, fix doctest
* Fix Windows tests
* Simplify test macros
* Simplify cp_solver tests
* Latest improvements to code
---------
Co-authored-by: Metehan Yıldırım <[email protected]>
Co-authored-by: Mehmet Ozan Kabak <[email protected]>
---
datafusion-cli/Cargo.lock | 1 +
.../src/physical_plan/joins/symmetric_hash_join.rs | 424 ++++++++----
datafusion/physical-expr/Cargo.toml | 1 +
.../physical-expr/src/intervals/cp_solver.rs | 741 ++++++++++-----------
.../src/intervals/interval_aritmetic.rs | 133 +++-
datafusion/physical-expr/src/intervals/mod.rs | 4 +-
datafusion/physical-expr/src/intervals/rounding.rs | 401 +++++++++++
.../physical-expr/src/intervals/test_utils.rs | 35 +-
8 files changed, 1152 insertions(+), 588 deletions(-)
diff --git a/datafusion-cli/Cargo.lock b/datafusion-cli/Cargo.lock
index a0a7137653..bf7e7934a0 100644
--- a/datafusion-cli/Cargo.lock
+++ b/datafusion-cli/Cargo.lock
@@ -834,6 +834,7 @@ dependencies = [
"indexmap",
"itertools",
"lazy_static",
+ "libc",
"md-5",
"paste",
"petgraph",
diff --git a/datafusion/core/src/physical_plan/joins/symmetric_hash_join.rs
b/datafusion/core/src/physical_plan/joins/symmetric_hash_join.rs
index 0cfb79141e..23d3d70848 100644
--- a/datafusion/core/src/physical_plan/joins/symmetric_hash_join.rs
+++ b/datafusion/core/src/physical_plan/joins/symmetric_hash_join.rs
@@ -1548,7 +1548,7 @@ impl SymmetricHashJoinStream {
mod tests {
use std::fs::File;
- use arrow::array::{ArrayRef, IntervalDayTimeArray};
+ use arrow::array::{ArrayRef, Float64Array, IntervalDayTimeArray};
use arrow::array::{Int32Array, TimestampMillisecondArray};
use arrow::compute::SortOptions;
use arrow::datatypes::{DataType, Field, IntervalUnit, Schema, TimeUnit};
@@ -1559,7 +1559,7 @@ mod tests {
use datafusion_expr::Operator;
use datafusion_physical_expr::expressions::{binary, col, Column};
use datafusion_physical_expr::intervals::test_utils::{
- gen_conjunctive_numeric_expr, gen_conjunctive_temporal_expr,
+ gen_conjunctive_numerical_expr, gen_conjunctive_temporal_expr,
};
use datafusion_physical_expr::PhysicalExpr;
@@ -1711,127 +1711,184 @@ mod tests {
Ok(result)
}
- fn join_expr_tests_fixture(
- expr_id: usize,
- left_col: Arc<dyn PhysicalExpr>,
- right_col: Arc<dyn PhysicalExpr>,
- ) -> Arc<dyn PhysicalExpr> {
- match expr_id {
- // left_col + 1 > right_col + 5 AND left_col + 3 < right_col + 10
- 0 => gen_conjunctive_numeric_expr(
- left_col,
- right_col,
- Operator::Plus,
- Operator::Plus,
- Operator::Plus,
- Operator::Plus,
- 1,
- 5,
- 3,
- 10,
- (Operator::Gt, Operator::Lt),
- ),
- // left_col - 1 > right_col + 5 AND left_col + 3 < right_col + 10
- 1 => gen_conjunctive_numeric_expr(
- left_col,
- right_col,
- Operator::Minus,
- Operator::Plus,
- Operator::Plus,
- Operator::Plus,
- 1,
- 5,
- 3,
- 10,
- (Operator::Gt, Operator::Lt),
- ),
- // left_col - 1 > right_col + 5 AND left_col - 3 < right_col + 10
- 2 => gen_conjunctive_numeric_expr(
- left_col,
- right_col,
- Operator::Minus,
- Operator::Plus,
- Operator::Minus,
- Operator::Plus,
- 1,
- 5,
- 3,
- 10,
- (Operator::Gt, Operator::Lt),
- ),
- // left_col - 10 > right_col - 5 AND left_col - 3 < right_col + 10
- 3 => gen_conjunctive_numeric_expr(
- left_col,
- right_col,
- Operator::Minus,
- Operator::Minus,
- Operator::Minus,
- Operator::Plus,
- 10,
- 5,
- 3,
- 10,
- (Operator::Gt, Operator::Lt),
- ),
- // left_col - 10 > right_col - 5 AND left_col - 30 < right_col - 3
- 4 => gen_conjunctive_numeric_expr(
- left_col,
- right_col,
- Operator::Minus,
- Operator::Minus,
- Operator::Minus,
- Operator::Minus,
- 10,
- 5,
- 30,
- 3,
- (Operator::Gt, Operator::Lt),
- ),
- // left_col - 2 >= right_col - 5 AND left_col - 7 <= right_col - 3
- 5 => gen_conjunctive_numeric_expr(
- left_col,
- right_col,
- Operator::Minus,
- Operator::Plus,
- Operator::Plus,
- Operator::Minus,
- 2,
- 5,
- 7,
- 3,
- (Operator::GtEq, Operator::LtEq),
- ),
- // left_col - 28 >= right_col - 11 AND left_col - 21 <= right_col
- 39
- 6 => gen_conjunctive_numeric_expr(
- left_col,
- right_col,
- Operator::Plus,
- Operator::Minus,
- Operator::Plus,
- Operator::Plus,
- 28,
- 11,
- 21,
- 39,
- (Operator::Gt, Operator::LtEq),
- ),
- // left_col - 28 >= right_col - 11 AND left_col - 21 <= right_col
- 39
- 7 => gen_conjunctive_numeric_expr(
- left_col,
- right_col,
- Operator::Plus,
- Operator::Minus,
- Operator::Minus,
- Operator::Plus,
- 28,
- 11,
- 21,
- 39,
- (Operator::GtEq, Operator::Lt),
- ),
- _ => unreachable!(),
+ // It creates join filters for different type of fields for testing.
+ macro_rules! join_expr_tests {
+ ($func_name:ident, $type:ty, $SCALAR:ident) => {
+ fn $func_name(
+ expr_id: usize,
+ left_col: Arc<dyn PhysicalExpr>,
+ right_col: Arc<dyn PhysicalExpr>,
+ ) -> Arc<dyn PhysicalExpr> {
+ match expr_id {
+ // left_col + 1 > right_col + 5 AND left_col + 3 <
right_col + 10
+ 0 => gen_conjunctive_numerical_expr(
+ left_col,
+ right_col,
+ (
+ Operator::Plus,
+ Operator::Plus,
+ Operator::Plus,
+ Operator::Plus,
+ ),
+ ScalarValue::$SCALAR(Some(1 as $type)),
+ ScalarValue::$SCALAR(Some(5 as $type)),
+ ScalarValue::$SCALAR(Some(3 as $type)),
+ ScalarValue::$SCALAR(Some(10 as $type)),
+ (Operator::Gt, Operator::Lt),
+ ),
+ // left_col - 1 > right_col + 5 AND left_col + 3 <
right_col + 10
+ 1 => gen_conjunctive_numerical_expr(
+ left_col,
+ right_col,
+ (
+ Operator::Minus,
+ Operator::Plus,
+ Operator::Plus,
+ Operator::Plus,
+ ),
+ ScalarValue::$SCALAR(Some(1 as $type)),
+ ScalarValue::$SCALAR(Some(5 as $type)),
+ ScalarValue::$SCALAR(Some(3 as $type)),
+ ScalarValue::$SCALAR(Some(10 as $type)),
+ (Operator::Gt, Operator::Lt),
+ ),
+ // left_col - 1 > right_col + 5 AND left_col - 3 <
right_col + 10
+ 2 => gen_conjunctive_numerical_expr(
+ left_col,
+ right_col,
+ (
+ Operator::Minus,
+ Operator::Plus,
+ Operator::Minus,
+ Operator::Plus,
+ ),
+ ScalarValue::$SCALAR(Some(1 as $type)),
+ ScalarValue::$SCALAR(Some(5 as $type)),
+ ScalarValue::$SCALAR(Some(3 as $type)),
+ ScalarValue::$SCALAR(Some(10 as $type)),
+ (Operator::Gt, Operator::Lt),
+ ),
+ // left_col - 10 > right_col - 5 AND left_col - 3 <
right_col + 10
+ 3 => gen_conjunctive_numerical_expr(
+ left_col,
+ right_col,
+ (
+ Operator::Minus,
+ Operator::Minus,
+ Operator::Minus,
+ Operator::Plus,
+ ),
+ ScalarValue::$SCALAR(Some(10 as $type)),
+ ScalarValue::$SCALAR(Some(5 as $type)),
+ ScalarValue::$SCALAR(Some(3 as $type)),
+ ScalarValue::$SCALAR(Some(10 as $type)),
+ (Operator::Gt, Operator::Lt),
+ ),
+ // left_col - 10 > right_col - 5 AND left_col - 30 <
right_col - 3
+ 4 => gen_conjunctive_numerical_expr(
+ left_col,
+ right_col,
+ (
+ Operator::Minus,
+ Operator::Minus,
+ Operator::Minus,
+ Operator::Minus,
+ ),
+ ScalarValue::$SCALAR(Some(10 as $type)),
+ ScalarValue::$SCALAR(Some(5 as $type)),
+ ScalarValue::$SCALAR(Some(30 as $type)),
+ ScalarValue::$SCALAR(Some(3 as $type)),
+ (Operator::Gt, Operator::Lt),
+ ),
+ // left_col - 2 >= right_col - 5 AND left_col - 7 <=
right_col - 3
+ 5 => gen_conjunctive_numerical_expr(
+ left_col,
+ right_col,
+ (
+ Operator::Minus,
+ Operator::Plus,
+ Operator::Plus,
+ Operator::Minus,
+ ),
+ ScalarValue::$SCALAR(Some(2 as $type)),
+ ScalarValue::$SCALAR(Some(5 as $type)),
+ ScalarValue::$SCALAR(Some(7 as $type)),
+ ScalarValue::$SCALAR(Some(3 as $type)),
+ (Operator::GtEq, Operator::LtEq),
+ ),
+ // left_col - 28 >= right_col - 11 AND left_col - 21 <=
right_col - 39
+ 6 => gen_conjunctive_numerical_expr(
+ left_col,
+ right_col,
+ (
+ Operator::Plus,
+ Operator::Minus,
+ Operator::Plus,
+ Operator::Plus,
+ ),
+ ScalarValue::$SCALAR(Some(28 as $type)),
+ ScalarValue::$SCALAR(Some(11 as $type)),
+ ScalarValue::$SCALAR(Some(21 as $type)),
+ ScalarValue::$SCALAR(Some(39 as $type)),
+ (Operator::Gt, Operator::LtEq),
+ ),
+ // left_col - 28 >= right_col - 11 AND left_col - 21 <=
right_col - 39
+ 7 => gen_conjunctive_numerical_expr(
+ left_col,
+ right_col,
+ (
+ Operator::Plus,
+ Operator::Minus,
+ Operator::Minus,
+ Operator::Plus,
+ ),
+ ScalarValue::$SCALAR(Some(28 as $type)),
+ ScalarValue::$SCALAR(Some(11 as $type)),
+ ScalarValue::$SCALAR(Some(21 as $type)),
+ ScalarValue::$SCALAR(Some(39 as $type)),
+ (Operator::GtEq, Operator::Lt),
+ ),
+ _ => panic!("No case"),
+ }
+ }
+ };
+ }
+
+ join_expr_tests!(join_expr_tests_fixture_i32, i32, Int32);
+ join_expr_tests!(join_expr_tests_fixture_f64, f64, Float64);
+
+ use rand::rngs::StdRng;
+ use rand::{Rng, SeedableRng};
+ use std::iter::Iterator;
+
+ struct AscendingRandomFloatIterator {
+ prev: f64,
+ max: f64,
+ rng: StdRng,
+ }
+
+ impl AscendingRandomFloatIterator {
+ fn new(min: f64, max: f64) -> Self {
+ let mut rng = StdRng::seed_from_u64(42);
+ let initial = rng.gen_range(min..max);
+ AscendingRandomFloatIterator {
+ prev: initial,
+ max,
+ rng,
+ }
}
}
+
+ impl Iterator for AscendingRandomFloatIterator {
+ type Item = f64;
+
+ fn next(&mut self) -> Option<Self::Item> {
+ let value = self.rng.gen_range(self.prev..self.max);
+ self.prev = value;
+ Some(value)
+ }
+ }
+
fn join_expr_tests_fixture_temporal(
expr_id: usize,
left_col: Arc<dyn PhysicalExpr>,
@@ -1887,12 +1944,18 @@ mod tests {
let cardinality = Arc::new(Int32Array::from_iter(
initial_range.clone().map(|x| x % 4).collect::<Vec<i32>>(),
));
- let cardinality_key = Arc::new(Int32Array::from_iter(
+ let cardinality_key_left = Arc::new(Int32Array::from_iter(
initial_range
.clone()
.map(|x| x % key_cardinality.0)
.collect::<Vec<i32>>(),
));
+ let cardinality_key_right = Arc::new(Int32Array::from_iter(
+ initial_range
+ .clone()
+ .map(|x| x % key_cardinality.1)
+ .collect::<Vec<i32>>(),
+ ));
let ordered_asc_null_first = Arc::new(Int32Array::from_iter({
std::iter::repeat(None)
.take(index as usize)
@@ -1926,10 +1989,15 @@ mod tests {
.collect::<Vec<i64>>(),
));
+ let float_asc = Arc::new(Float64Array::from_iter_values(
+ AscendingRandomFloatIterator::new(0., table_size as f64)
+ .take(table_size as usize),
+ ));
+
let left = RecordBatch::try_from_iter(vec![
("la1", ordered.clone()),
("lb1", cardinality.clone()),
- ("lc1", cardinality_key.clone()),
+ ("lc1", cardinality_key_left),
("lt1", time.clone()),
("la2", ordered.clone()),
("la1_des", ordered_des.clone()),
@@ -1937,11 +2005,12 @@ mod tests {
("l_asc_null_last", ordered_asc_null_last.clone()),
("l_desc_null_first", ordered_desc_null_first.clone()),
("li1", interval_time.clone()),
+ ("l_float", float_asc.clone()),
])?;
let right = RecordBatch::try_from_iter(vec![
("ra1", ordered.clone()),
("rb1", cardinality),
- ("rc1", cardinality_key),
+ ("rc1", cardinality_key_right),
("rt1", time),
("ra2", ordered),
("ra1_des", ordered_des),
@@ -1949,6 +2018,7 @@ mod tests {
("r_asc_null_last", ordered_asc_null_last),
("r_desc_null_first", ordered_desc_null_first),
("ri1", interval_time),
+ ("r_float", float_asc),
])?;
Ok((left, right))
}
@@ -2140,7 +2210,7 @@ mod tests {
Field::new("left", DataType::Int32, true),
Field::new("right", DataType::Int32, true),
]);
- let filter_expr = join_expr_tests_fixture(
+ let filter_expr = join_expr_tests_fixture_i32(
case_expr,
col("left", &intermediate_schema)?,
col("right", &intermediate_schema)?,
@@ -2201,7 +2271,7 @@ mod tests {
Field::new("left", DataType::Int32, true),
Field::new("right", DataType::Int32, true),
]);
- let filter_expr = join_expr_tests_fixture(
+ let filter_expr = join_expr_tests_fixture_i32(
case_expr,
col("left", &intermediate_schema)?,
col("right", &intermediate_schema)?,
@@ -2312,7 +2382,7 @@ mod tests {
Field::new("left", DataType::Int32, true),
Field::new("right", DataType::Int32, true),
]);
- let filter_expr = join_expr_tests_fixture(
+ let filter_expr = join_expr_tests_fixture_i32(
case_expr,
col("left", &intermediate_schema)?,
col("right", &intermediate_schema)?,
@@ -2537,7 +2607,7 @@ mod tests {
Field::new("left", DataType::Int32, true),
Field::new("right", DataType::Int32, true),
]);
- let filter_expr = join_expr_tests_fixture(
+ let filter_expr = join_expr_tests_fixture_i32(
case_expr,
col("left", &intermediate_schema)?,
col("right", &intermediate_schema)?,
@@ -2600,7 +2670,7 @@ mod tests {
Field::new("left", DataType::Int32, true),
Field::new("right", DataType::Int32, true),
]);
- let filter_expr = join_expr_tests_fixture(
+ let filter_expr = join_expr_tests_fixture_i32(
case_expr,
col("left", &intermediate_schema)?,
col("right", &intermediate_schema)?,
@@ -2664,7 +2734,7 @@ mod tests {
Field::new("left", DataType::Int32, true),
Field::new("right", DataType::Int32, true),
]);
- let filter_expr = join_expr_tests_fixture(
+ let filter_expr = join_expr_tests_fixture_i32(
case_expr,
col("left", &intermediate_schema)?,
col("right", &intermediate_schema)?,
@@ -2802,17 +2872,19 @@ mod tests {
Field::new("0", DataType::Int32, true),
Field::new("1", DataType::Int32, true),
]);
- let filter_expr = gen_conjunctive_numeric_expr(
+ let filter_expr = gen_conjunctive_numerical_expr(
col("0", &intermediate_schema)?,
col("1", &intermediate_schema)?,
- Operator::Plus,
- Operator::Minus,
- Operator::Plus,
- Operator::Plus,
- 0,
- 3,
- 0,
- 3,
+ (
+ Operator::Plus,
+ Operator::Minus,
+ Operator::Plus,
+ Operator::Plus,
+ ),
+ ScalarValue::Int32(Some(0)),
+ ScalarValue::Int32(Some(3)),
+ ScalarValue::Int32(Some(0)),
+ ScalarValue::Int32(Some(3)),
(Operator::Gt, Operator::Lt),
);
let column_indices = vec![
@@ -3033,4 +3105,78 @@ mod tests {
Ok(())
}
+
+ #[rstest]
+ #[tokio::test(flavor = "multi_thread")]
+ async fn testing_ascending_float_pruning(
+ #[values(
+ JoinType::Inner,
+ JoinType::Left,
+ JoinType::Right,
+ JoinType::RightSemi,
+ JoinType::LeftSemi,
+ JoinType::LeftAnti,
+ JoinType::RightAnti,
+ JoinType::Full
+ )]
+ join_type: JoinType,
+ #[values(
+ (4, 5),
+ (99, 12),
+ )]
+ cardinality: (i32, i32),
+ #[values(0, 1, 2, 3, 4, 5, 6, 7)] case_expr: usize,
+ ) -> Result<()> {
+ let config = SessionConfig::new().with_repartition_joins(false);
+ let session_ctx = SessionContext::with_config(config);
+ let task_ctx = session_ctx.task_ctx();
+ let (left_batch, right_batch) =
+ build_sides_record_batches(TABLE_SIZE, cardinality)?;
+ let left_schema = &left_batch.schema();
+ let right_schema = &right_batch.schema();
+ let left_sorted = vec![PhysicalSortExpr {
+ expr: col("l_float", left_schema)?,
+ options: SortOptions::default(),
+ }];
+ let right_sorted = vec![PhysicalSortExpr {
+ expr: col("r_float", right_schema)?,
+ options: SortOptions::default(),
+ }];
+ let (left, right) = create_memory_table(
+ left_batch,
+ right_batch,
+ Some(left_sorted),
+ Some(right_sorted),
+ 13,
+ )?;
+
+ let on = vec![(
+ Column::new_with_schema("lc1", left_schema)?,
+ Column::new_with_schema("rc1", right_schema)?,
+ )];
+
+ let intermediate_schema = Schema::new(vec![
+ Field::new("left", DataType::Float64, true),
+ Field::new("right", DataType::Float64, true),
+ ]);
+ let filter_expr = join_expr_tests_fixture_f64(
+ case_expr,
+ col("left", &intermediate_schema)?,
+ col("right", &intermediate_schema)?,
+ );
+ let column_indices = vec![
+ ColumnIndex {
+ index: 10, // l_float
+ side: JoinSide::Left,
+ },
+ ColumnIndex {
+ index: 10, // r_float
+ side: JoinSide::Right,
+ },
+ ];
+ let filter = JoinFilter::new(filter_expr, column_indices,
intermediate_schema);
+
+ experiment(left, right, Some(filter), join_type, on, task_ctx).await?;
+ Ok(())
+ }
}
diff --git a/datafusion/physical-expr/Cargo.toml
b/datafusion/physical-expr/Cargo.toml
index b28ad534fb..31484bf793 100644
--- a/datafusion/physical-expr/Cargo.toml
+++ b/datafusion/physical-expr/Cargo.toml
@@ -58,6 +58,7 @@ hashbrown = { version = "0.13", features = ["raw"] }
indexmap = "1.9.2"
itertools = { version = "0.10", features = ["use_std"] }
lazy_static = { version = "^1.4.0" }
+libc = "0.2.140"
md-5 = { version = "^0.10.0", optional = true }
paste = "^1.0"
petgraph = "0.6.2"
diff --git a/datafusion/physical-expr/src/intervals/cp_solver.rs
b/datafusion/physical-expr/src/intervals/cp_solver.rs
index 65c8850b39..3a682049a0 100644
--- a/datafusion/physical-expr/src/intervals/cp_solver.rs
+++ b/datafusion/physical-expr/src/intervals/cp_solver.rs
@@ -552,10 +552,10 @@ pub fn check_support(expr: &Arc<dyn PhysicalExpr>) ->
bool {
#[cfg(test)]
mod tests {
use super::*;
- use crate::intervals::test_utils::gen_conjunctive_numeric_expr;
use itertools::Itertools;
use crate::expressions::{BinaryExpr, Column};
+ use crate::intervals::test_utils::gen_conjunctive_numerical_expr;
use datafusion_common::ScalarValue;
use rand::rngs::StdRng;
use rand::{Rng, SeedableRng};
@@ -564,31 +564,19 @@ mod tests {
fn experiment(
expr: Arc<dyn PhysicalExpr>,
exprs_with_interval: (Arc<dyn PhysicalExpr>, Arc<dyn PhysicalExpr>),
- left_interval: (Option<i32>, Option<i32>),
- right_interval: (Option<i32>, Option<i32>),
- left_expected: (Option<i32>, Option<i32>),
- right_expected: (Option<i32>, Option<i32>),
+ left_interval: Interval,
+ right_interval: Interval,
+ left_expected: Interval,
+ right_expected: Interval,
result: PropagationResult,
) -> Result<()> {
let col_stats = vec![
- (
- exprs_with_interval.0.clone(),
- Interval::make(left_interval.0, left_interval.1, (false,
false)),
- ),
- (
- exprs_with_interval.1.clone(),
- Interval::make(right_interval.0, right_interval.1, (false,
false)),
- ),
+ (exprs_with_interval.0.clone(), left_interval),
+ (exprs_with_interval.1.clone(), right_interval),
];
let expected = vec![
- (
- exprs_with_interval.0.clone(),
- Interval::make(left_expected.0, left_expected.1, (false,
false)),
- ),
- (
- exprs_with_interval.1.clone(),
- Interval::make(right_expected.0, right_expected.1, (false,
false)),
- ),
+ (exprs_with_interval.0.clone(), left_expected),
+ (exprs_with_interval.1.clone(), right_expected),
];
let mut graph = ExprIntervalGraph::try_new(expr)?;
let expr_indexes = graph
@@ -608,81 +596,71 @@ mod tests {
let exp_result = graph.update_ranges(&mut col_stat_nodes[..])?;
assert_eq!(exp_result, result);
col_stat_nodes.iter().zip(expected_nodes.iter()).for_each(
- |((_, res), (_, expected))| {
- // NOTE: These randomized tests only check the correnctness of
- // endpoint values, not open/closedness.
- assert_eq!(res.lower.value, expected.lower.value);
- assert_eq!(res.upper.value, expected.upper.value);
+ |((_, calculated_interval_node), (_, expected))| {
+ // NOTE: These randomized tests only check for conservative
containment,
+ // not openness/closedness of endpoints.
+ assert!(calculated_interval_node.lower.value <=
expected.lower.value);
+ assert!(calculated_interval_node.upper.value >=
expected.upper.value);
},
);
Ok(())
}
- fn generate_case<const ASC: bool>(
- expr: Arc<dyn PhysicalExpr>,
- left_col: Arc<dyn PhysicalExpr>,
- right_col: Arc<dyn PhysicalExpr>,
- seed: u64,
- expr_left: i32,
- expr_right: i32,
- ) -> Result<()> {
- let mut r = StdRng::seed_from_u64(seed);
-
- let (left_interval, right_interval, left_waited, right_waited) = if
ASC {
- let left = (Some(r.gen_range(0..1000)), None);
- let right = (Some(r.gen_range(0..1000)), None);
- (
- left,
- right,
- (
- Some(std::cmp::max(left.0.unwrap(), right.0.unwrap() +
expr_left)),
- None,
- ),
- (
- Some(std::cmp::max(
- right.0.unwrap(),
- left.0.unwrap() + expr_right,
- )),
- None,
- ),
- )
- } else {
- let left = (None, Some(r.gen_range(0..1000)));
- let right = (None, Some(r.gen_range(0..1000)));
- (
- left,
- right,
- (
- None,
- Some(std::cmp::min(left.1.unwrap(), right.1.unwrap() +
expr_left)),
- ),
- (
- None,
- Some(std::cmp::min(
- right.1.unwrap(),
- left.1.unwrap() + expr_right,
- )),
- ),
- )
+ macro_rules! generate_cases {
+ ($FUNC_NAME:ident, $TYPE:ty, $SCALAR:ident) => {
+ fn $FUNC_NAME<const ASC: bool>(
+ expr: Arc<dyn PhysicalExpr>,
+ left_col: Arc<dyn PhysicalExpr>,
+ right_col: Arc<dyn PhysicalExpr>,
+ seed: u64,
+ expr_left: $TYPE,
+ expr_right: $TYPE,
+ ) -> Result<()> {
+ let mut r = StdRng::seed_from_u64(seed);
+
+ let (left_given, right_given, left_expected, right_expected) =
if ASC {
+ let left = r.gen_range((0 as $TYPE)..(1000 as $TYPE));
+ let right = r.gen_range((0 as $TYPE)..(1000 as $TYPE));
+ (
+ (Some(left), None),
+ (Some(right), None),
+ (Some(<$TYPE>::max(left, right + expr_left)), None),
+ (Some(<$TYPE>::max(right, left + expr_right)), None),
+ )
+ } else {
+ let left = r.gen_range((0 as $TYPE)..(1000 as $TYPE));
+ let right = r.gen_range((0 as $TYPE)..(1000 as $TYPE));
+ (
+ (None, Some(left)),
+ (None, Some(right)),
+ (None, Some(<$TYPE>::min(left, right + expr_left))),
+ (None, Some(<$TYPE>::min(right, left + expr_right))),
+ )
+ };
+
+ experiment(
+ expr,
+ (left_col, right_col),
+ Interval::make(left_given.0, left_given.1, (true, true)),
+ Interval::make(right_given.0, right_given.1, (true, true)),
+ Interval::make(left_expected.0, left_expected.1, (true,
true)),
+ Interval::make(right_expected.0, right_expected.1, (true,
true)),
+ PropagationResult::Success,
+ )
+ }
};
- experiment(
- expr,
- (left_col, right_col),
- left_interval,
- right_interval,
- left_waited,
- right_waited,
- PropagationResult::Success,
- )?;
- Ok(())
}
+ generate_cases!(generate_case_i32, i32, Int32);
+ generate_cases!(generate_case_i64, i64, Int64);
+ generate_cases!(generate_case_f32, f32, Float32);
+ generate_cases!(generate_case_f64, f64, Float64);
#[test]
fn testing_not_possible() -> Result<()> {
let left_col = Arc::new(Column::new("left_watermark", 0));
let right_col = Arc::new(Column::new("right_watermark", 0));
- // left_watermark > right_watermark + 5
+ // left_watermark > right_watermark + 5
let left_and_1 = Arc::new(BinaryExpr::new(
left_col.clone(),
Operator::Plus,
@@ -692,341 +670,293 @@ mod tests {
experiment(
expr,
(left_col, right_col),
- (Some(10), Some(20)),
- (Some(100), None),
- (Some(10), Some(20)),
- (Some(100), None),
+ Interval::make(Some(10), Some(20), (true, true)),
+ Interval::make(Some(100), None, (true, true)),
+ Interval::make(Some(10), Some(20), (true, true)),
+ Interval::make(Some(100), None, (true, true)),
PropagationResult::Infeasible,
- )?;
- Ok(())
- }
-
- #[rstest]
- #[test]
- fn case_1(
- #[values(0, 1, 2, 3, 4, 12, 32, 314, 3124, 123, 123, 4123)] seed: u64,
- ) -> Result<()> {
- let left_col = Arc::new(Column::new("left_watermark", 0));
- let right_col = Arc::new(Column::new("right_watermark", 0));
- // left_watermark + 1 > right_watermark + 11 AND left_watermark + 3 <
right_watermark + 33
- let expr = gen_conjunctive_numeric_expr(
- left_col.clone(),
- right_col.clone(),
- Operator::Plus,
- Operator::Plus,
- Operator::Plus,
- Operator::Plus,
- 1,
- 11,
- 3,
- 33,
- (Operator::Gt, Operator::Lt),
- );
- // l > r + 10 AND r > l - 30
- let l_gt_r = 10;
- let r_gt_l = -30;
- generate_case::<true>(
- expr.clone(),
- left_col.clone(),
- right_col.clone(),
- seed,
- l_gt_r,
- r_gt_l,
- )?;
- // Descending tests
- // r < l - 10 AND l < r + 30
- let r_lt_l = -l_gt_r;
- let l_lt_r = -r_gt_l;
- generate_case::<false>(expr, left_col, right_col, seed, l_lt_r,
r_lt_l)?;
-
- Ok(())
- }
- #[rstest]
- #[test]
- fn case_2(
- #[values(0, 1, 2, 3, 4, 12, 32, 314, 3124, 123, 123, 4123)] seed: u64,
- ) -> Result<()> {
- let left_col = Arc::new(Column::new("left_watermark", 0));
- let right_col = Arc::new(Column::new("right_watermark", 0));
- // left_watermark - 1 > right_watermark + 5 AND left_watermark + 3 <
right_watermark + 10
- let expr = gen_conjunctive_numeric_expr(
- left_col.clone(),
- right_col.clone(),
- Operator::Minus,
- Operator::Plus,
- Operator::Plus,
- Operator::Plus,
- 1,
- 5,
- 3,
- 10,
- (Operator::Gt, Operator::Lt),
- );
- // l > r + 6 AND r > l - 7
- let l_gt_r = 6;
- let r_gt_l = -7;
- generate_case::<true>(
- expr.clone(),
- left_col.clone(),
- right_col.clone(),
- seed,
- l_gt_r,
- r_gt_l,
- )?;
- // Descending tests
- // r < l - 6 AND l < r + 7
- let r_lt_l = -l_gt_r;
- let l_lt_r = -r_gt_l;
- generate_case::<false>(expr, left_col, right_col, seed, l_lt_r,
r_lt_l)?;
-
- Ok(())
+ )
}
- #[rstest]
- #[test]
- fn case_3(
- #[values(0, 1, 2, 3, 4, 12, 32, 314, 3124, 123, 123, 4123)] seed: u64,
- ) -> Result<()> {
- let left_col = Arc::new(Column::new("left_watermark", 0));
- let right_col = Arc::new(Column::new("right_watermark", 0));
- // left_watermark - 1 > right_watermark + 5 AND left_watermark - 3 <
right_watermark + 10
- let expr = gen_conjunctive_numeric_expr(
- left_col.clone(),
- right_col.clone(),
- Operator::Minus,
- Operator::Plus,
- Operator::Minus,
- Operator::Plus,
- 1,
- 5,
- 3,
- 10,
- (Operator::Gt, Operator::Lt),
- );
- // l > r + 6 AND r > l - 13
- let l_gt_r = 6;
- let r_gt_l = -13;
- generate_case::<true>(
- expr.clone(),
- left_col.clone(),
- right_col.clone(),
- seed,
- l_gt_r,
- r_gt_l,
- )?;
- // Descending tests
- // r < l - 6 AND l < r + 13
- let r_lt_l = -l_gt_r;
- let l_lt_r = -r_gt_l;
- generate_case::<false>(expr, left_col, right_col, seed, l_lt_r,
r_lt_l)?;
-
- Ok(())
+ macro_rules! integer_float_case_1 {
+ ($TEST_FUNC_NAME:ident, $GENERATE_CASE_FUNC_NAME:ident, $TYPE:ty,
$SCALAR:ident) => {
+ #[rstest]
+ #[test]
+ fn $TEST_FUNC_NAME(
+ #[values(0, 1, 2, 3, 4, 12, 32, 314, 3124, 123, 125, 211, 215,
4123)]
+ seed: u64,
+ #[values(Operator::Gt, Operator::GtEq)] greater_op: Operator,
+ #[values(Operator::Lt, Operator::LtEq)] less_op: Operator,
+ ) -> Result<()> {
+ let left_col = Arc::new(Column::new("left_watermark", 0));
+ let right_col = Arc::new(Column::new("right_watermark", 0));
+
+ // left_watermark + 1 > right_watermark + 11 AND
left_watermark + 3 < right_watermark + 33
+ let expr = gen_conjunctive_numerical_expr(
+ left_col.clone(),
+ right_col.clone(),
+ (
+ Operator::Plus,
+ Operator::Plus,
+ Operator::Plus,
+ Operator::Plus,
+ ),
+ ScalarValue::$SCALAR(Some(1 as $TYPE)),
+ ScalarValue::$SCALAR(Some(11 as $TYPE)),
+ ScalarValue::$SCALAR(Some(3 as $TYPE)),
+ ScalarValue::$SCALAR(Some(33 as $TYPE)),
+ (greater_op, less_op),
+ );
+ // l > r + 10 AND r > l - 30
+ let l_gt_r = 10 as $TYPE;
+ let r_gt_l = -30 as $TYPE;
+ $GENERATE_CASE_FUNC_NAME::<true>(
+ expr.clone(),
+ left_col.clone(),
+ right_col.clone(),
+ seed,
+ l_gt_r,
+ r_gt_l,
+ )?;
+ // Descending tests
+ // r < l - 10 AND l < r + 30
+ let r_lt_l = -l_gt_r;
+ let l_lt_r = -r_gt_l;
+ $GENERATE_CASE_FUNC_NAME::<false>(
+ expr, left_col, right_col, seed, l_lt_r, r_lt_l,
+ )
+ }
+ };
}
- #[rstest]
- #[test]
- fn case_4(
- #[values(0, 1, 2, 3, 4, 12, 32, 314, 3124, 123, 123, 4123)] seed: u64,
- ) -> Result<()> {
- let left_col = Arc::new(Column::new("left_watermark", 0));
- let right_col = Arc::new(Column::new("right_watermark", 0));
- // left_watermark - 10 > right_watermark - 5 AND left_watermark - 3 <
right_watermark + 10
- let expr = gen_conjunctive_numeric_expr(
- left_col.clone(),
- right_col.clone(),
- Operator::Minus,
- Operator::Minus,
- Operator::Minus,
- Operator::Plus,
- 10,
- 5,
- 3,
- 10,
- (Operator::Gt, Operator::Lt),
- );
- // l > r + 5 AND r > l - 13
- let l_gt_r = 5;
- let r_gt_l = -13;
- generate_case::<true>(
- expr.clone(),
- left_col.clone(),
- right_col.clone(),
- seed,
- l_gt_r,
- r_gt_l,
- )?;
- // Descending tests
- // r < l - 5 AND l < r + 13
- let r_lt_l = -l_gt_r;
- let l_lt_r = -r_gt_l;
- generate_case::<false>(expr, left_col, right_col, seed, l_lt_r,
r_lt_l)?;
- Ok(())
+ integer_float_case_1!(case_1_i32, generate_case_i32, i32, Int32);
+ integer_float_case_1!(case_1_i64, generate_case_i64, i64, Int64);
+ integer_float_case_1!(case_1_f64, generate_case_f64, f64, Float64);
+ integer_float_case_1!(case_1_f32, generate_case_f32, f32, Float32);
+
+ macro_rules! integer_float_case_2 {
+ ($TEST_FUNC_NAME:ident, $GENERATE_CASE_FUNC_NAME:ident, $TYPE:ty,
$SCALAR:ident) => {
+ #[rstest]
+ #[test]
+ fn $TEST_FUNC_NAME(
+ #[values(0, 1, 2, 3, 4, 12, 32, 314, 3124, 123, 125, 211, 215,
4123)]
+ seed: u64,
+ #[values(Operator::Gt, Operator::GtEq)] greater_op: Operator,
+ #[values(Operator::Lt, Operator::LtEq)] less_op: Operator,
+ ) -> Result<()> {
+ let left_col = Arc::new(Column::new("left_watermark", 0));
+ let right_col = Arc::new(Column::new("right_watermark", 0));
+
+ // left_watermark - 1 > right_watermark + 5 AND left_watermark
+ 3 < right_watermark + 10
+ let expr = gen_conjunctive_numerical_expr(
+ left_col.clone(),
+ right_col.clone(),
+ (
+ Operator::Minus,
+ Operator::Plus,
+ Operator::Plus,
+ Operator::Plus,
+ ),
+ ScalarValue::$SCALAR(Some(1 as $TYPE)),
+ ScalarValue::$SCALAR(Some(5 as $TYPE)),
+ ScalarValue::$SCALAR(Some(3 as $TYPE)),
+ ScalarValue::$SCALAR(Some(10 as $TYPE)),
+ (greater_op, less_op),
+ );
+ // l > r + 6 AND r > l - 7
+ let l_gt_r = 6 as $TYPE;
+ let r_gt_l = -7 as $TYPE;
+ $GENERATE_CASE_FUNC_NAME::<true>(
+ expr.clone(),
+ left_col.clone(),
+ right_col.clone(),
+ seed,
+ l_gt_r,
+ r_gt_l,
+ )?;
+ // Descending tests
+ // r < l - 6 AND l < r + 7
+ let r_lt_l = -l_gt_r;
+ let l_lt_r = -r_gt_l;
+ $GENERATE_CASE_FUNC_NAME::<false>(
+ expr, left_col, right_col, seed, l_lt_r, r_lt_l,
+ )
+ }
+ };
}
- #[rstest]
- #[test]
- fn case_5(
- #[values(0, 1, 2, 3, 4, 12, 32, 314, 3124, 123, 123, 4123)] seed: u64,
- ) -> Result<()> {
- let left_col = Arc::new(Column::new("left_watermark", 0));
- let right_col = Arc::new(Column::new("right_watermark", 0));
- // left_watermark - 10 > right_watermark - 5 AND left_watermark - 30 <
right_watermark - 3
-
- let expr = gen_conjunctive_numeric_expr(
- left_col.clone(),
- right_col.clone(),
- Operator::Minus,
- Operator::Minus,
- Operator::Minus,
- Operator::Minus,
- 10,
- 5,
- 30,
- 3,
- (Operator::Gt, Operator::Lt),
- );
- // l > r + 5 AND r > l - 27
- let l_gt_r = 5;
- let r_gt_l = -27;
- generate_case::<true>(
- expr.clone(),
- left_col.clone(),
- right_col.clone(),
- seed,
- l_gt_r,
- r_gt_l,
- )?;
- // Descending tests
- // r < l - 5 AND l < r + 27
- let r_lt_l = -l_gt_r;
- let l_lt_r = -r_gt_l;
- generate_case::<false>(expr, left_col, right_col, seed, l_lt_r,
r_lt_l)?;
-
- Ok(())
+ integer_float_case_2!(case_2_i32, generate_case_i32, i32, Int32);
+ integer_float_case_2!(case_2_i64, generate_case_i64, i64, Int64);
+ integer_float_case_2!(case_2_f64, generate_case_f64, f64, Float64);
+ integer_float_case_2!(case_2_f32, generate_case_f32, f32, Float32);
+
+ macro_rules! integer_float_case_3 {
+ ($TEST_FUNC_NAME:ident, $GENERATE_CASE_FUNC_NAME:ident, $TYPE:ty,
$SCALAR:ident) => {
+ #[rstest]
+ #[test]
+ fn $TEST_FUNC_NAME(
+ #[values(0, 1, 2, 3, 4, 12, 32, 314, 3124, 123, 125, 211, 215,
4123)]
+ seed: u64,
+ #[values(Operator::Gt, Operator::GtEq)] greater_op: Operator,
+ #[values(Operator::Lt, Operator::LtEq)] less_op: Operator,
+ ) -> Result<()> {
+ let left_col = Arc::new(Column::new("left_watermark", 0));
+ let right_col = Arc::new(Column::new("right_watermark", 0));
+
+ // left_watermark - 1 > right_watermark + 5 AND left_watermark
- 3 < right_watermark + 10
+ let expr = gen_conjunctive_numerical_expr(
+ left_col.clone(),
+ right_col.clone(),
+ (
+ Operator::Minus,
+ Operator::Plus,
+ Operator::Minus,
+ Operator::Plus,
+ ),
+ ScalarValue::$SCALAR(Some(1 as $TYPE)),
+ ScalarValue::$SCALAR(Some(5 as $TYPE)),
+ ScalarValue::$SCALAR(Some(3 as $TYPE)),
+ ScalarValue::$SCALAR(Some(10 as $TYPE)),
+ (greater_op, less_op),
+ );
+ // l > r + 6 AND r > l - 13
+ let l_gt_r = 6 as $TYPE;
+ let r_gt_l = -13 as $TYPE;
+ $GENERATE_CASE_FUNC_NAME::<true>(
+ expr.clone(),
+ left_col.clone(),
+ right_col.clone(),
+ seed,
+ l_gt_r,
+ r_gt_l,
+ )?;
+ // Descending tests
+ // r < l - 6 AND l < r + 13
+ let r_lt_l = -l_gt_r;
+ let l_lt_r = -r_gt_l;
+ $GENERATE_CASE_FUNC_NAME::<false>(
+ expr, left_col, right_col, seed, l_lt_r, r_lt_l,
+ )
+ }
+ };
}
- #[rstest]
- #[test]
- fn case_6(
- #[values(0, 1, 2, 123, 756, 63, 345, 6443, 12341, 142, 123, 8900)]
seed: u64,
- ) -> Result<()> {
- let left_col = Arc::new(Column::new("left_watermark", 0));
- let right_col = Arc::new(Column::new("right_watermark", 0));
- // left_watermark - 1 >= right_watermark + 5 AND left_watermark - 10
<= right_watermark + 3
-
- let expr = gen_conjunctive_numeric_expr(
- left_col.clone(),
- right_col.clone(),
- Operator::Minus,
- Operator::Plus,
- Operator::Minus,
- Operator::Plus,
- 1,
- 5,
- 10,
- 3,
- (Operator::GtEq, Operator::LtEq),
- );
- // l >= r + 6 AND r >= l - 13
- let l_gt_r = 6;
- let r_gt_l = -13;
-
- generate_case::<true>(
- expr.clone(),
- left_col.clone(),
- right_col.clone(),
- seed,
- l_gt_r,
- r_gt_l,
- )?;
- generate_case::<true>(expr, left_col, right_col, seed, l_gt_r,
r_gt_l)?;
-
- Ok(())
+ // Instantiate case 3 for i32, i64, f64 and f32 columns.
+ integer_float_case_3!(case_3_i32, generate_case_i32, i32, Int32);
+ integer_float_case_3!(case_3_i64, generate_case_i64, i64, Int64);
+ integer_float_case_3!(case_3_f64, generate_case_f64, f64, Float64);
+ integer_float_case_3!(case_3_f32, generate_case_f32, f32, Float32);
+
+ // Generates an rstest named `$TEST_FUNC_NAME` for case 4 over the numeric
+ // type `$TYPE`; literals are built with `ScalarValue::$SCALAR`. rstest
+ // produces one test per combination of the `#[values]` below. The body calls
+ // `$GENERATE_CASE_FUNC_NAME` once with `true` and once with `false` (the
+ // latter is the "Descending tests" call, using the negated offsets).
+ macro_rules! integer_float_case_4 {
+ ($TEST_FUNC_NAME:ident, $GENERATE_CASE_FUNC_NAME:ident, $TYPE:ty,
$SCALAR:ident) => {
+ #[rstest]
+ #[test]
+ fn $TEST_FUNC_NAME(
+ #[values(0, 1, 2, 3, 4, 12, 32, 314, 3124, 123, 125, 211, 215,
4123)]
+ seed: u64,
+ #[values(Operator::Gt, Operator::GtEq)] greater_op: Operator,
+ #[values(Operator::Lt, Operator::LtEq)] less_op: Operator,
+ ) -> Result<()> {
+ let left_col = Arc::new(Column::new("left_watermark", 0));
+ let right_col = Arc::new(Column::new("right_watermark", 0));
+
+ // left_watermark - 10 > right_watermark - 5 AND left_watermark - 3 < right_watermark + 10
+ // (`>` stands for `greater_op` and `<` for `less_op`).
+ let expr = gen_conjunctive_numerical_expr(
+ left_col.clone(),
+ right_col.clone(),
+ (
+ Operator::Minus,
+ Operator::Minus,
+ Operator::Minus,
+ Operator::Plus,
+ ),
+ ScalarValue::$SCALAR(Some(10 as $TYPE)),
+ ScalarValue::$SCALAR(Some(5 as $TYPE)),
+ ScalarValue::$SCALAR(Some(3 as $TYPE)),
+ ScalarValue::$SCALAR(Some(10 as $TYPE)),
+ (greater_op, less_op),
+ );
+ // l > r + 5 AND r > l - 13
+ let l_gt_r = 5 as $TYPE;
+ let r_gt_l = -13 as $TYPE;
+ $GENERATE_CASE_FUNC_NAME::<true>(
+ expr.clone(),
+ left_col.clone(),
+ right_col.clone(),
+ seed,
+ l_gt_r,
+ r_gt_l,
+ )?;
+ // Descending tests
+ // r < l - 5 AND l < r + 13
+ let r_lt_l = -l_gt_r;
+ let l_lt_r = -r_gt_l;
+ $GENERATE_CASE_FUNC_NAME::<false>(
+ expr, left_col, right_col, seed, l_lt_r, r_lt_l,
+ )
+ }
+ };
}
- #[rstest]
- #[test]
- fn case_7(
- #[values(0, 1, 2, 123, 77, 93, 104, 624, 115, 613, 8365, 9345)] seed:
u64,
- ) -> Result<()> {
- let left_col = Arc::new(Column::new("left_watermark", 0));
- let right_col = Arc::new(Column::new("right_watermark", 0));
- // left_watermark + 4 >= right_watermark + 5 AND left_watermark - 20 <
right_watermark - 5
-
- let expr = gen_conjunctive_numeric_expr(
- left_col.clone(),
- right_col.clone(),
- Operator::Plus,
- Operator::Plus,
- Operator::Minus,
- Operator::Minus,
- 4,
- 5,
- 20,
- 5,
- (Operator::GtEq, Operator::Lt),
- );
- // l >= r + 1 AND r > l - 15
- let l_gt_r = 1;
- let r_gt_l = -15;
- generate_case::<true>(
- expr.clone(),
- left_col.clone(),
- right_col.clone(),
- seed,
- l_gt_r,
- r_gt_l,
- )?;
- // Descending tests
- // r < l - 5 AND l < r + 27
- let r_lt_l = -l_gt_r;
- let l_lt_r = -r_gt_l;
- generate_case::<false>(expr, left_col, right_col, seed, l_lt_r,
r_lt_l)?;
-
- Ok(())
+ // Instantiate case 4 for i32, i64, f64 and f32 columns.
+ integer_float_case_4!(case_4_i32, generate_case_i32, i32, Int32);
+ integer_float_case_4!(case_4_i64, generate_case_i64, i64, Int64);
+ integer_float_case_4!(case_4_f64, generate_case_f64, f64, Float64);
+ integer_float_case_4!(case_4_f32, generate_case_f32, f32, Float32);
+
+ // Generates an rstest named `$TEST_FUNC_NAME` for case 5 over the numeric
+ // type `$TYPE`; literals are built with `ScalarValue::$SCALAR`. rstest
+ // produces one test per combination of the `#[values]` below. The body calls
+ // `$GENERATE_CASE_FUNC_NAME` once with `true` and once with `false` (the
+ // latter is the "Descending tests" call, using the negated offsets).
+ macro_rules! integer_float_case_5 {
+ ($TEST_FUNC_NAME:ident, $GENERATE_CASE_FUNC_NAME:ident, $TYPE:ty,
$SCALAR:ident) => {
+ #[rstest]
+ #[test]
+ fn $TEST_FUNC_NAME(
+ #[values(0, 1, 2, 3, 4, 12, 32, 314, 3124, 123, 125, 211, 215,
4123)]
+ seed: u64,
+ #[values(Operator::Gt, Operator::GtEq)] greater_op: Operator,
+ #[values(Operator::Lt, Operator::LtEq)] less_op: Operator,
+ ) -> Result<()> {
+ let left_col = Arc::new(Column::new("left_watermark", 0));
+ let right_col = Arc::new(Column::new("right_watermark", 0));
+
+ // left_watermark - 10 > right_watermark - 5 AND left_watermark - 30 < right_watermark - 3
+ // (`>` stands for `greater_op` and `<` for `less_op`).
+ let expr = gen_conjunctive_numerical_expr(
+ left_col.clone(),
+ right_col.clone(),
+ (
+ Operator::Minus,
+ Operator::Minus,
+ Operator::Minus,
+ Operator::Minus,
+ ),
+ ScalarValue::$SCALAR(Some(10 as $TYPE)),
+ ScalarValue::$SCALAR(Some(5 as $TYPE)),
+ ScalarValue::$SCALAR(Some(30 as $TYPE)),
+ ScalarValue::$SCALAR(Some(3 as $TYPE)),
+ (greater_op, less_op),
+ );
+ // l > r + 5 AND r > l - 27
+ let l_gt_r = 5 as $TYPE;
+ let r_gt_l = -27 as $TYPE;
+ $GENERATE_CASE_FUNC_NAME::<true>(
+ expr.clone(),
+ left_col.clone(),
+ right_col.clone(),
+ seed,
+ l_gt_r,
+ r_gt_l,
+ )?;
+ // Descending tests
+ // r < l - 5 AND l < r + 27
+ let r_lt_l = -l_gt_r;
+ let l_lt_r = -r_gt_l;
+ $GENERATE_CASE_FUNC_NAME::<false>(
+ expr, left_col, right_col, seed, l_lt_r, r_lt_l,
+ )
+ }
+ };
}
- #[rstest]
- #[test]
- fn case_8(
- #[values(0, 1, 2, 24, 53, 412, 364, 345, 737, 1010, 52, 1554)] seed:
u64,
- ) -> Result<()> {
- let left_col = Arc::new(Column::new("left_watermark", 0));
- let right_col = Arc::new(Column::new("right_watermark", 0));
- // left_watermark + 4 >= right_watermark + 5 AND left_watermark - 20 <
right_watermark - 5
-
- let expr = gen_conjunctive_numeric_expr(
- left_col.clone(),
- right_col.clone(),
- Operator::Plus,
- Operator::Plus,
- Operator::Minus,
- Operator::Minus,
- 4,
- 5,
- 20,
- 5,
- (Operator::Gt, Operator::LtEq),
- );
- // l >= r + 1 AND r > l - 15
- let l_gt_r = 1;
- let r_gt_l = -15;
- generate_case::<true>(
- expr.clone(),
- left_col.clone(),
- right_col.clone(),
- seed,
- l_gt_r,
- r_gt_l,
- )?;
- // Descending tests
- // r < l - 5 AND l < r + 27
- let r_lt_l = -l_gt_r;
- let l_lt_r = -r_gt_l;
- generate_case::<false>(expr, left_col, right_col, seed, l_lt_r,
r_lt_l)?;
-
- Ok(())
- }
+ // Instantiate case 5 for i32, i64, f64 and f32 columns.
+ integer_float_case_5!(case_5_i32, generate_case_i32, i32, Int32);
+ integer_float_case_5!(case_5_i64, generate_case_i64, i64, Int64);
+ integer_float_case_5!(case_5_f64, generate_case_f64, f64, Float64);
+ integer_float_case_5!(case_5_f32, generate_case_f32, f32, Float32);
#[test]
fn test_gather_node_indices_dont_remove() -> Result<()> {
@@ -1067,6 +997,7 @@ mod tests {
assert_eq!(prev_node_count, final_node_count);
Ok(())
}
+
#[test]
fn test_gather_node_indices_remove() -> Result<()> {
// Expression: a@0 + b@1 + 1 > y@0 - z@1, given a@0 + b@1.
diff --git a/datafusion/physical-expr/src/intervals/interval_aritmetic.rs
b/datafusion/physical-expr/src/intervals/interval_aritmetic.rs
index 9a4b0bfe8d..6c2d1b0f41 100644
--- a/datafusion/physical-expr/src/intervals/interval_aritmetic.rs
+++ b/datafusion/physical-expr/src/intervals/interval_aritmetic.rs
@@ -28,6 +28,7 @@ use datafusion_expr::type_coercion::binary::coerce_types;
use datafusion_expr::Operator;
use crate::aggregate::min_max::{max, min};
+use crate::intervals::rounding::alter_fp_rounding_mode;
/// This type represents a single endpoint of an [`Interval`]. An endpoint can
/// be open or closed, denoting whether the interval includes or excludes the
@@ -75,38 +76,54 @@ impl IntervalBound {
/// The result is unbounded if either is; otherwise, their values are
/// added. The result is closed if both original bounds are closed, or open
/// otherwise.
+ ///
+ /// `UPPER` selects the rounding direction for `Float32`/`Float64` values:
+ /// `true` rounds the sum towards positive infinity (for an upper bound),
+ /// `false` rounds it towards negative infinity (for a lower bound). This is
+ /// intended to keep the computed bound from being tighter than the exact
+ /// sum. Other data types ignore `UPPER`.
- pub fn add<T: Borrow<IntervalBound>>(&self, other: T) ->
Result<IntervalBound> {
+ pub fn add<const UPPER: bool, T: Borrow<IntervalBound>>(
+ &self,
+ other: T,
+ ) -> Result<IntervalBound> {
let rhs = other.borrow();
if self.is_unbounded() || rhs.is_unbounded() {
- IntervalBound::make_unbounded(coerce_types(
+ return IntervalBound::make_unbounded(coerce_types(
&self.get_datatype(),
&Operator::Plus,
&rhs.get_datatype(),
- )?)
- } else {
- self.value
- .add(&rhs.value)
- .map(|v| IntervalBound::new(v, self.open || rhs.open))
+ )?);
}
+ // Directed rounding is applied only to float values; other types use the
+ // plain addition. Only `self`'s type is inspected — presumably both bounds
+ // share one type at this point (TODO confirm against callers).
+ match self.get_datatype() {
+ DataType::Float64 | DataType::Float32 => {
+ alter_fp_rounding_mode::<UPPER, _>(&self.value, &rhs.value,
|lhs, rhs| {
+ lhs.add(rhs)
+ })
+ }
+ _ => self.value.add(&rhs.value),
+ }
+ .map(|v| IntervalBound::new(v, self.open || rhs.open))
}
/// This function subtracts the given `IntervalBound` from `self`.
/// The result is unbounded if either is; otherwise, their values are
/// subtracted. The result is closed if both original bounds are closed,
/// or open otherwise.
+ ///
+ /// `UPPER` selects the rounding direction for `Float32`/`Float64` values:
+ /// `true` rounds the difference towards positive infinity (for an upper
+ /// bound), `false` rounds it towards negative infinity (for a lower bound).
+ /// Other data types ignore `UPPER`.
- pub fn sub<T: Borrow<IntervalBound>>(&self, other: T) ->
Result<IntervalBound> {
+ pub fn sub<const UPPER: bool, T: Borrow<IntervalBound>>(
+ &self,
+ other: T,
+ ) -> Result<IntervalBound> {
let rhs = other.borrow();
if self.is_unbounded() || rhs.is_unbounded() {
- IntervalBound::make_unbounded(coerce_types(
+ return IntervalBound::make_unbounded(coerce_types(
&self.get_datatype(),
&Operator::Minus,
&rhs.get_datatype(),
- )?)
- } else {
- self.value
- .sub(&rhs.value)
- .map(|v| IntervalBound::new(v, self.open || rhs.open))
+ )?);
+ }
+ // Directed rounding is applied only to float values; other types use the
+ // plain subtraction. Only `self`'s type is inspected — presumably both
+ // bounds share one type at this point (TODO confirm against callers).
+ match self.get_datatype() {
+ DataType::Float64 | DataType::Float32 => {
+ alter_fp_rounding_mode::<UPPER, _>(&self.value, &rhs.value,
|lhs, rhs| {
+ lhs.sub(rhs)
+ })
+ }
+ _ => self.value.sub(&rhs.value),
}
+ .map(|v| IntervalBound::new(v, self.open || rhs.open))
}
/// This function chooses one of the given `IntervalBound`s according to
@@ -404,8 +421,8 @@ impl Interval {
pub fn add<T: Borrow<Interval>>(&self, other: T) -> Result<Interval> {
let rhs = other.borrow();
+ // [a, b] + [c, d] = [a + c, b + d]. The lower end is rounded down
+ // (`false`) and the upper end up (`true`) for float bounds, so the result
+ // is not narrower than the exact interval sum.
Ok(Interval::new(
- self.lower.add(&rhs.lower)?,
- self.upper.add(&rhs.upper)?,
+ self.lower.add::<false, _>(&rhs.lower)?,
+ self.upper.add::<true, _>(&rhs.upper)?,
))
}
@@ -416,8 +433,8 @@ impl Interval {
pub fn sub<T: Borrow<Interval>>(&self, other: T) -> Result<Interval> {
let rhs = other.borrow();
+ // [a, b] - [c, d] = [a - d, b - c]. The lower end is rounded down
+ // (`false`) and the upper end up (`true`) for float bounds, so the result
+ // is not narrower than the exact interval difference.
Ok(Interval::new(
- self.lower.sub(&rhs.upper)?,
- self.upper.sub(&rhs.lower)?,
+ self.lower.sub::<false, _>(&rhs.upper)?,
+ self.upper.sub::<true, _>(&rhs.lower)?,
))
}
@@ -463,6 +480,8 @@ pub fn is_datatype_supported(data_type: &DataType) -> bool {
| &DataType::UInt32
| &DataType::UInt16
| &DataType::UInt8
+ | &DataType::Float64
+ | &DataType::Float32
)
}
@@ -1041,7 +1060,7 @@ mod tests {
// This function tests if valid constructions produce standardized objects
// ([false, false], [false, true], [true, true]) for boolean intervals.
#[test]
- fn non_standard_interval_constructs() -> Result<()> {
+ fn non_standard_interval_constructs() {
let cases = vec![
(
IntervalBound::new(Boolean(None), true),
@@ -1078,6 +1097,80 @@ mod tests {
for case in cases {
assert_eq!(Interval::new(case.0, case.1), case.2)
}
- Ok(())
+ }
+
+ // For a float type `$TYPE`, generates `create_interval_$TYPE` and
+ // `capture_mode_change_$TYPE` (names pasted together via `paste::item!`).
+ macro_rules! capture_mode_change {
+ ($TYPE:ty) => {
+ paste::item! {
+ capture_mode_change_helper!([<capture_mode_change_ $TYPE>],
+ [<create_interval_ $TYPE>],
+ $TYPE);
+ }
+ };
+ }
+
+ // `$CREATE_FN_NAME(lower, upper)` builds an `Interval` via `Interval::make`
+ // (the `(true, true)` flags are presumably the bounds' open/closed flags —
+ // confirm against `Interval::make`).
+ // `$TEST_FN_NAME((a, b), expect_low, expect_high)` adds the point intervals
+ // [a, a] and [b, b]. It then compares the result with `a + b` computed
+ // natively in the default rounding mode: the lower bound must be strictly
+ // smaller when `expect_low`, the upper bound strictly larger when
+ // `expect_high`.
+ macro_rules! capture_mode_change_helper {
+ ($TEST_FN_NAME:ident, $CREATE_FN_NAME:ident, $TYPE:ty) => {
+ fn $CREATE_FN_NAME(lower: $TYPE, upper: $TYPE) -> Interval {
+ Interval::make(Some(lower as $TYPE), Some(upper as $TYPE),
(true, true))
+ }
+
+ fn $TEST_FN_NAME(input: ($TYPE, $TYPE), expect_low: bool,
expect_high: bool) {
+ // At least one bound must be expected to move, otherwise the
+ // check below would be vacuous.
+ assert!(expect_low || expect_high);
+ let interval1 = $CREATE_FN_NAME(input.0, input.0);
+ let interval2 = $CREATE_FN_NAME(input.1, input.1);
+ let result = interval1.add(&interval2).unwrap();
+ // Reference sum computed without changing the rounding mode.
+ let without_fe = $CREATE_FN_NAME(input.0 + input.1, input.0 +
input.1);
+ assert!(
+ (!expect_low || result.lower.value <
without_fe.lower.value)
+ && (!expect_high || result.upper.value >
without_fe.upper.value)
+ );
+ }
+ };
+ }
+
+ capture_mode_change!(f32);
+ capture_mode_change!(f64);
+
+ // Runs only where `fesetround` is used (non-Windows x86_64/aarch64). Each
+ // input pair is chosen so that directed rounding moves exactly one bound
+ // away from the round-to-nearest sum. Despite its name, this test covers
+ // f64 inputs as well as f32.
+ #[cfg(all(
+ any(target_arch = "x86_64", target_arch = "aarch64"),
+ not(target_os = "windows")
+ ))]
+ #[test]
+ fn test_add_intervals_lower_affected_f32() {
+ // Lower is affected
+ let lower = f32::from_bits(1073741887); // 0b1000000000000000000000000111111
+ let upper = f32::from_bits(1098907651); // 0b1000001100000000000000000000011
+ capture_mode_change_f32((lower, upper), true, false);
+
+ // Upper is affected
+ let lower = f32::from_bits(1072693248); // 0b111111111100000000000000000000
+ let upper = f32::from_bits(715827883); // 0b101010101010101010101010101011
+ capture_mode_change_f32((lower, upper), false, true);
+
+ // Lower is affected
+ let lower = 1.0; // 0x3FF0000000000000
+ let upper = 0.3; // 0x3FD3333333333333
+ capture_mode_change_f64((lower, upper), true, false);
+
+ // Upper is affected
+ let lower = 1.4999999999999998; // 0x3FF7FFFFFFFFFFFF
+ let upper = 0.000_000_000_000_000_022_044_604_925_031_31; // 0x3C796A6B413BB21F
+ capture_mode_change_f64((lower, upper), false, true);
+ }
+
+ // Runs only where the conservative fallback is used (targets other than
+ // x86_64/aarch64, and Windows). 1.5 + 1.5 = 3.0 is exact, yet the fallback
+ // still moves each bound one step outward, so both bounds must change.
+ // Despite its name, this test covers f32 as well as f64.
+ #[cfg(any(
+ not(any(target_arch = "x86_64", target_arch = "aarch64")),
+ target_os = "windows"
+ ))]
+ #[test]
+ fn test_next_impl_add_intervals_f64() {
+ let lower = 1.5;
+ let upper = 1.5;
+ capture_mode_change_f64((lower, upper), true, true);
+
+ let lower = 1.5;
+ let upper = 1.5;
+ capture_mode_change_f32((lower, upper), true, true);
}
}
diff --git a/datafusion/physical-expr/src/intervals/mod.rs
b/datafusion/physical-expr/src/intervals/mod.rs
index 9883ba15b2..a9255752fe 100644
--- a/datafusion/physical-expr/src/intervals/mod.rs
+++ b/datafusion/physical-expr/src/intervals/mod.rs
@@ -15,11 +15,11 @@
// specific language governing permissions and limitations
// under the License.
-//! Interval calculations
-//!
+//! Interval arithmetic and constraint propagation library
pub mod cp_solver;
pub mod interval_aritmetic;
+pub mod rounding;
pub mod test_utils;
pub use cp_solver::{check_support, ExprIntervalGraph};
diff --git a/datafusion/physical-expr/src/intervals/rounding.rs
b/datafusion/physical-expr/src/intervals/rounding.rs
new file mode 100644
index 0000000000..06c4f9e8a9
--- /dev/null
+++ b/datafusion/physical-expr/src/intervals/rounding.rs
@@ -0,0 +1,401 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+//! Floating-point rounding mode utility library.
+//!
+//! TODO: Remove this custom implementation and the "libc" dependency when
+//! floating-point rounding mode manipulation functions become available
+//! in Rust.
+
+use std::ops::{Add, BitAnd, Sub};
+
+use datafusion_common::Result;
+use datafusion_common::ScalarValue;
+
+// Rounding-direction values passed to `fesetround` and returned by
+// `fegetround`. They must equal the `FE_UPWARD` / `FE_DOWNWARD` macros of the
+// target C library's `<fenv.h>`.
+// NOTE(review): the values below are assumed to match the glibc (and macOS)
+// headers for each architecture — verify when targeting another C library.
+
+// Define constants for AArch64
+#[cfg(all(target_arch = "aarch64", not(target_os = "windows")))]
+const FE_UPWARD: i32 = 0x00400000;
+#[cfg(all(target_arch = "aarch64", not(target_os = "windows")))]
+const FE_DOWNWARD: i32 = 0x00800000;
+
+// Define constants for x86_64
+#[cfg(all(target_arch = "x86_64", not(target_os = "windows")))]
+const FE_UPWARD: i32 = 0x0800;
+#[cfg(all(target_arch = "x86_64", not(target_os = "windows")))]
+const FE_DOWNWARD: i32 = 0x0400;
+
+// NOTE(review): presumably kept so the `libc` crate is linked on the targets
+// that call `fesetround`/`fegetround` — confirm it is still required.
+#[cfg(all(
+ any(target_arch = "x86_64", target_arch = "aarch64"),
+ not(target_os = "windows")
+))]
+extern crate libc;
+
+// Declarations of the C `<fenv.h>` rounding-mode functions. They must match
+// the C prototypes exactly: calling a foreign function through a declaration
+// with a different signature is undefined behavior.
+#[cfg(all(
+ any(target_arch = "x86_64", target_arch = "aarch64"),
+ not(target_os = "windows")
+))]
+extern "C" {
+ // C: `int fesetround(int round)` — sets the rounding direction and returns
+ // zero on success, non-zero otherwise. The previous declaration omitted
+ // the `int` return type.
+ fn fesetround(round: i32) -> i32;
+ // C: `int fegetround(void)` — returns the current rounding direction.
+ fn fegetround() -> i32;
+}
+
+/// A trait to manipulate floating-point types with bitwise operations.
+/// Provides functions to convert a floating-point value to/from its bitwise
+/// representation as well as utility methods to handle special values.
+///
+/// It is implemented for `f32` (bits held in a `u32`) and `f64` (bits held in
+/// a `u64`). [`next_up`] and [`next_down`] use it to step between adjacent
+/// values by adding or subtracting one from the bit pattern.
+pub trait FloatBits {
+ /// The integer type used for bitwise operations.
+ type Item: Copy
+ + PartialEq
+ + BitAnd<Output = Self::Item>
+ + Add<Output = Self::Item>
+ + Sub<Output = Self::Item>;
+
+ /// Bit pattern of the smallest positive (subnormal) value of this type.
+ const TINY_BITS: Self::Item;
+
+ /// Bit pattern of the negative value closest to zero.
+ const NEG_TINY_BITS: Self::Item;
+
+ /// Mask that clears the sign bit: `bits & CLEAR_SIGN_MASK` is the bit
+ /// pattern of the value's absolute value.
+ const CLEAR_SIGN_MASK: Self::Item;
+
+ /// The integer value 1, used in bitwise operations.
+ const ONE: Self::Item;
+
+ /// The integer value 0, used in bitwise operations.
+ const ZERO: Self::Item;
+
+ /// Converts the floating-point value to its bitwise representation.
+ fn to_bits(self) -> Self::Item;
+
+ /// Converts the bitwise representation to the corresponding floating-point value.
+ fn from_bits(bits: Self::Item) -> Self;
+
+ /// Returns true if the floating-point value is NaN (not a number).
+ fn float_is_nan(self) -> bool;
+
+ /// Returns the positive infinity value for this floating-point type.
+ fn infinity() -> Self;
+
+ /// Returns the negative infinity value for this floating-point type.
+ fn neg_infinity() -> Self;
+}
+
+// `FloatBits` for `f32`; the bit pattern is a `u32`.
+impl FloatBits for f32 {
+ type Item = u32;
+ const TINY_BITS: u32 = 0x1; // Smallest positive f32.
+ const NEG_TINY_BITS: u32 = 0x8000_0001; // Smallest (in magnitude) negative f32.
+ const CLEAR_SIGN_MASK: u32 = 0x7fff_ffff;
+ const ONE: Self::Item = 1;
+ const ZERO: Self::Item = 0;
+
+ fn to_bits(self) -> Self::Item {
+ // Inherent methods take precedence over trait methods in method
+ // resolution, so this calls `f32::to_bits` rather than recursing.
+ self.to_bits()
+ }
+
+ fn from_bits(bits: Self::Item) -> Self {
+ f32::from_bits(bits)
+ }
+
+ fn float_is_nan(self) -> bool {
+ self.is_nan()
+ }
+
+ fn infinity() -> Self {
+ f32::INFINITY
+ }
+
+ fn neg_infinity() -> Self {
+ f32::NEG_INFINITY
+ }
+}
+
+// `FloatBits` for `f64`; the bit pattern is a `u64`.
+impl FloatBits for f64 {
+ type Item = u64;
+ const TINY_BITS: u64 = 0x1; // Smallest positive f64.
+ const NEG_TINY_BITS: u64 = 0x8000_0000_0000_0001; // Smallest (in magnitude) negative f64.
+ const CLEAR_SIGN_MASK: u64 = 0x7fff_ffff_ffff_ffff;
+ const ONE: Self::Item = 1;
+ const ZERO: Self::Item = 0;
+
+ fn to_bits(self) -> Self::Item {
+ // Inherent methods take precedence over trait methods in method
+ // resolution, so this calls `f64::to_bits` rather than recursing.
+ self.to_bits()
+ }
+
+ fn from_bits(bits: Self::Item) -> Self {
+ f64::from_bits(bits)
+ }
+
+ fn float_is_nan(self) -> bool {
+ self.is_nan()
+ }
+
+ fn infinity() -> Self {
+ f64::INFINITY
+ }
+
+ fn neg_infinity() -> Self {
+ f64::NEG_INFINITY
+ }
+}
+
+/// Returns the next representable floating-point value greater than the input value.
+///
+/// This function takes a floating-point value that implements the [`FloatBits`] trait,
+/// calculates the next representable value greater than the input, and returns it.
+///
+/// If the input value is NaN or positive infinity, the function returns the input value.
+/// Both `0.0` and `-0.0` map to the smallest positive value. The largest finite value
+/// maps to positive infinity, and negative infinity maps to the most negative finite value.
+///
+/// # Examples
+///
+/// ```
+/// use datafusion_physical_expr::intervals::rounding::next_up;
+///
+/// let f: f32 = 1.0;
+/// let next_f = next_up(f);
+/// assert_eq!(next_f, 1.0000001);
+/// ```
+// NOTE(review): `next_up` is `pub`, so this allow looks unnecessary — confirm
+// before removing.
+#[allow(dead_code)]
+pub fn next_up<F: FloatBits + Copy>(float: F) -> F {
+ let bits = float.to_bits();
+ if float.float_is_nan() || bits == F::infinity().to_bits() {
+ return float;
+ }
+
+ // Bit pattern of |float|; zero for both +0.0 and -0.0.
+ let abs = bits & F::CLEAR_SIGN_MASK;
+ let next_bits = if abs == F::ZERO {
+ // Step off either zero to the smallest positive value.
+ F::TINY_BITS
+ } else if bits == abs {
+ // Positive: a larger bit pattern is a larger value.
+ bits + F::ONE
+ } else {
+ // Negative (sign-magnitude): a smaller magnitude is a larger value.
+ bits - F::ONE
+ };
+ F::from_bits(next_bits)
+}
+
+/// Returns the next representable floating-point value smaller than the input value.
+///
+/// This function takes a floating-point value that implements the [`FloatBits`] trait,
+/// calculates the next representable value smaller than the input, and returns it.
+///
+/// If the input value is NaN or negative infinity, the function returns the input value.
+/// Both `0.0` and `-0.0` map to the negative value closest to zero. The most negative
+/// finite value maps to negative infinity, and positive infinity maps to the largest
+/// finite value.
+///
+/// # Examples
+///
+/// ```
+/// use datafusion_physical_expr::intervals::rounding::next_down;
+///
+/// let f: f32 = 1.0;
+/// let next_f = next_down(f);
+/// assert_eq!(next_f, 0.99999994);
+/// ```
+// NOTE(review): `next_down` is `pub`, so this allow looks unnecessary — confirm
+// before removing.
+#[allow(dead_code)]
+pub fn next_down<F: FloatBits + Copy>(float: F) -> F {
+ let bits = float.to_bits();
+ if float.float_is_nan() || bits == F::neg_infinity().to_bits() {
+ return float;
+ }
+ // Bit pattern of |float|; zero for both +0.0 and -0.0.
+ let abs = bits & F::CLEAR_SIGN_MASK;
+ let next_bits = if abs == F::ZERO {
+ // Step off either zero to the negative value closest to zero.
+ F::NEG_TINY_BITS
+ } else if bits == abs {
+ // Positive: a smaller bit pattern is a smaller value.
+ bits - F::ONE
+ } else {
+ // Negative (sign-magnitude): a larger magnitude is a smaller value.
+ bits + F::ONE
+ };
+ F::from_bits(next_bits)
+}
+
+/// Fallback for targets where `fesetround` is not used (anything other than
+/// x86_64/aarch64, and Windows). It runs `operation` in the default rounding
+/// mode, then moves a non-null `Float64`/`Float32` result one step outward:
+/// up with [`next_up`] when `UPPER` is `true`, down with [`next_down`]
+/// otherwise.
+///
+/// The step is taken even when the result was exact, so bounds may be one
+/// ULP looser than necessary. Assuming the default mode rounds to nearest,
+/// this keeps an upper bound at or above, and a lower bound at or below, the
+/// exact result. Null floats and non-float results are returned unchanged;
+/// errors from `operation` are propagated.
+#[cfg(any(
+ not(any(target_arch = "x86_64", target_arch = "aarch64")),
+ target_os = "windows"
+))]
+fn alter_fp_rounding_mode_conservative<const UPPER: bool, F>(
+ lhs: &ScalarValue,
+ rhs: &ScalarValue,
+ operation: F,
+) -> Result<ScalarValue>
+where
+ F: FnOnce(&ScalarValue, &ScalarValue) -> Result<ScalarValue>,
+{
+ let mut result = operation(lhs, rhs)?;
+ match &mut result {
+ ScalarValue::Float64(Some(value)) => {
+ if UPPER {
+ *value = next_up(*value)
+ } else {
+ *value = next_down(*value)
+ }
+ }
+ ScalarValue::Float32(Some(value)) => {
+ if UPPER {
+ *value = next_up(*value)
+ } else {
+ *value = next_down(*value)
+ }
+ }
+ _ => {}
+ };
+ Ok(result)
+}
+
+/// Performs `operation` on `lhs` and `rhs` with floating-point results rounded
+/// towards positive infinity when `UPPER` is `true` and towards negative
+/// infinity when it is `false`. The result can then serve as an upper or
+/// lower interval bound, respectively.
+///
+/// On non-Windows x86_64/aarch64 targets, the rounding mode is switched with
+/// `fesetround` for the duration of `operation`. It is restored afterwards,
+/// including when `operation` panics. On every other target, the operation
+/// runs in the default rounding mode and a float result is then moved one
+/// step outward (see `alter_fp_rounding_mode_conservative`).
+///
+/// # Errors
+///
+/// Returns any error produced by `operation`.
+pub fn alter_fp_rounding_mode<const UPPER: bool, F>(
+ lhs: &ScalarValue,
+ rhs: &ScalarValue,
+ operation: F,
+) -> Result<ScalarValue>
+where
+ F: FnOnce(&ScalarValue, &ScalarValue) -> Result<ScalarValue>,
+{
+ #[cfg(all(
+ any(target_arch = "x86_64", target_arch = "aarch64"),
+ not(target_os = "windows")
+ ))]
+ {
+ // Restores the saved rounding mode when dropped. Without it, a panic
+ // inside `operation` would unwind past the restoring call and leave the
+ // thread in a directed rounding mode.
+ struct RoundingModeGuard(i32);
+
+ impl Drop for RoundingModeGuard {
+ fn drop(&mut self) {
+ // SAFETY: `self.0` was returned by `fegetround`, so it is a
+ // rounding mode this platform accepts.
+ unsafe {
+ fesetround(self.0);
+ }
+ }
+ }
+
+ // SAFETY: `fegetround` has no preconditions, and `FE_UPWARD` /
+ // `FE_DOWNWARD` are this target's rounding-mode constants.
+ let _restore = unsafe {
+ let guard = RoundingModeGuard(fegetround());
+ fesetround(if UPPER { FE_UPWARD } else { FE_DOWNWARD });
+ guard
+ };
+ // NOTE(review): LLVM assumes the default floating-point environment, so
+ // in principle it may constant-fold or move float operations across
+ // these opaque calls; this code relies on it not doing so — confirm.
+ operation(lhs, rhs)
+ }
+ #[cfg(any(
+ not(any(target_arch = "x86_64", target_arch = "aarch64")),
+ target_os = "windows"
+ ))]
+ alter_fp_rounding_mode_conservative::<UPPER, _>(lhs, rhs, operation)
+}
+
+#[cfg(test)]
+mod tests {
+ // Unit tests for `next_up` / `next_down` on f64 and f32: the values +1.0
+ // and -1.0, the matching infinity, and NaN.
+ // NOTE(review): signed zeros and the largest finite values are not covered.
+ use super::{next_down, next_up};
+
+ #[test]
+ fn test_next_down() {
+ let x = 1.0f64;
+ // Clamp value into range [0, 1).
+ let clamped = x.clamp(0.0, next_down(1.0f64));
+ assert!(clamped < 1.0);
+ assert_eq!(next_up(clamped), 1.0);
+ }
+
+ #[test]
+ fn test_next_up_small_positive() {
+ let value: f64 = 1.0;
+ let result = next_up(value);
+ assert_eq!(result, 1.0000000000000002);
+ }
+
+ #[test]
+ fn test_next_up_small_negative() {
+ let value: f64 = -1.0;
+ let result = next_up(value);
+ assert_eq!(result, -0.9999999999999999);
+ }
+
+ #[test]
+ fn test_next_up_pos_infinity() {
+ let value: f64 = f64::INFINITY;
+ let result = next_up(value);
+ assert_eq!(result, f64::INFINITY);
+ }
+
+ #[test]
+ fn test_next_up_nan() {
+ let value: f64 = f64::NAN;
+ let result = next_up(value);
+ assert!(result.is_nan());
+ }
+
+ #[test]
+ fn test_next_down_small_positive() {
+ let value: f64 = 1.0;
+ let result = next_down(value);
+ assert_eq!(result, 0.9999999999999999);
+ }
+
+ #[test]
+ fn test_next_down_small_negative() {
+ let value: f64 = -1.0;
+ let result = next_down(value);
+ assert_eq!(result, -1.0000000000000002);
+ }
+
+ #[test]
+ fn test_next_down_neg_infinity() {
+ let value: f64 = f64::NEG_INFINITY;
+ let result = next_down(value);
+ assert_eq!(result, f64::NEG_INFINITY);
+ }
+
+ #[test]
+ fn test_next_down_nan() {
+ let value: f64 = f64::NAN;
+ let result = next_down(value);
+ assert!(result.is_nan());
+ }
+
+ #[test]
+ fn test_next_up_small_positive_f32() {
+ let value: f32 = 1.0;
+ let result = next_up(value);
+ assert_eq!(result, 1.0000001);
+ }
+
+ #[test]
+ fn test_next_up_small_negative_f32() {
+ let value: f32 = -1.0;
+ let result = next_up(value);
+ assert_eq!(result, -0.99999994);
+ }
+
+ #[test]
+ fn test_next_up_pos_infinity_f32() {
+ let value: f32 = f32::INFINITY;
+ let result = next_up(value);
+ assert_eq!(result, f32::INFINITY);
+ }
+
+ #[test]
+ fn test_next_up_nan_f32() {
+ let value: f32 = f32::NAN;
+ let result = next_up(value);
+ assert!(result.is_nan());
+ }
+ #[test]
+ fn test_next_down_small_positive_f32() {
+ let value: f32 = 1.0;
+ let result = next_down(value);
+ assert_eq!(result, 0.99999994);
+ }
+ #[test]
+ fn test_next_down_small_negative_f32() {
+ let value: f32 = -1.0;
+ let result = next_down(value);
+ assert_eq!(result, -1.0000001);
+ }
+ #[test]
+ fn test_next_down_neg_infinity_f32() {
+ let value: f32 = f32::NEG_INFINITY;
+ let result = next_down(value);
+ assert_eq!(result, f32::NEG_INFINITY);
+ }
+ #[test]
+ fn test_next_down_nan_f32() {
+ let value: f32 = f32::NAN;
+ let result = next_down(value);
+ assert!(result.is_nan());
+ }
+}
diff --git a/datafusion/physical-expr/src/intervals/test_utils.rs
b/datafusion/physical-expr/src/intervals/test_utils.rs
index f233b246a7..6bbf74dc7d 100644
--- a/datafusion/physical-expr/src/intervals/test_utils.rs
+++ b/datafusion/physical-expr/src/intervals/test_utils.rs
@@ -29,40 +29,31 @@ use datafusion_expr::Operator;
/// This test function generates a conjunctive statement with two numeric
/// terms with the following form:
/// left_col (op_1) a >/>= right_col (op_2) b AND left_col (op_3) c </<=
right_col (op_4) d
-pub fn gen_conjunctive_numeric_expr(
+pub fn gen_conjunctive_numerical_expr(
left_col: Arc<dyn PhysicalExpr>,
right_col: Arc<dyn PhysicalExpr>,
- op_1: Operator,
- op_2: Operator,
- op_3: Operator,
- op_4: Operator,
- a: i32,
- b: i32,
- c: i32,
- d: i32,
+ op: (Operator, Operator, Operator, Operator),
+ a: ScalarValue,
+ b: ScalarValue,
+ c: ScalarValue,
+ d: ScalarValue,
bounds: (Operator, Operator),
) -> Arc<dyn PhysicalExpr> {
+ let (op_1, op_2, op_3, op_4) = op;
let left_and_1 = Arc::new(BinaryExpr::new(
left_col.clone(),
op_1,
- Arc::new(Literal::new(ScalarValue::Int32(Some(a)))),
+ Arc::new(Literal::new(a)),
));
let left_and_2 = Arc::new(BinaryExpr::new(
right_col.clone(),
op_2,
- Arc::new(Literal::new(ScalarValue::Int32(Some(b)))),
- ));
-
- let right_and_1 = Arc::new(BinaryExpr::new(
- left_col,
- op_3,
- Arc::new(Literal::new(ScalarValue::Int32(Some(c)))),
- ));
- let right_and_2 = Arc::new(BinaryExpr::new(
- right_col,
- op_4,
- Arc::new(Literal::new(ScalarValue::Int32(Some(d)))),
+ Arc::new(Literal::new(b)),
));
+ let right_and_1 =
+ Arc::new(BinaryExpr::new(left_col, op_3, Arc::new(Literal::new(c))));
+ let right_and_2 =
+ Arc::new(BinaryExpr::new(right_col, op_4, Arc::new(Literal::new(d))));
let (greater_op, less_op) = bounds;
let left_expr = Arc::new(BinaryExpr::new(left_and_1, greater_op,
left_and_2));