This is an automated email from the ASF dual-hosted git repository.
alamb pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/arrow-datafusion.git
The following commit(s) were added to refs/heads/main by this push:
new b7ed06dff2 Refactor AnalysisContext and statistics() of FilterExec
(#6982)
b7ed06dff2 is described below
commit b7ed06dff23fd987a9b29439886a34c0b4e2b342
Author: Berkay Şahin <[email protected]>
AuthorDate: Thu Jul 20 16:42:21 2023 +0300
Refactor AnalysisContext and statistics() of FilterExec (#6982)
* Min/Max in ExprBoundaries are replaced with Interval
* Minor fix
* min max values replaced by intervals, analyze() computations are done
with intervals
* Floating points are computed more wisely
* simplifications
* simplifications
* Floating point selectivity is calculated considering interval
cardinalities.
* Interval ranges are calculated with cp_solver lib
* Remove the float equality case due to diverse result of windows test
* Revert "Remove the float equality case due to diverse result of windows
test"
This reverts commit d10182a8364a8628fd83884b680d0169140cd528.
* No need to assign literal intervals
* clean clone
* Multi columns may be evaluated in some cases, cont'd
* First iteration of refactoring
1) Analyze is removed from the methods of PhysicalExpr.
2) Interval arithmetic is applied to AnalysisContext values.
3) Intervals of input columns are updated now.
* Linter fix
* Tests added, bugs fixed
* Refactor on code
* Refactor of functions, fix for win tests
* minor changes
* Comments enriched
* Simplify collect_statistics
* Comment added for float value's selectivity
* Remove code duplication
* Update datafusion/physical-expr/src/intervals/interval_aritmetic.rs
Co-authored-by: Mehmet Ozan Kabak <[email protected]>
* Update datafusion/physical-expr/src/intervals/interval_aritmetic.rs
Co-authored-by: Mehmet Ozan Kabak <[email protected]>
* Update datafusion/physical-expr/src/intervals/interval_aritmetic.rs
Co-authored-by: Mehmet Ozan Kabak <[email protected]>
* Update datafusion/physical-expr/src/intervals/interval_aritmetic.rs
Co-authored-by: Mehmet Ozan Kabak <[email protected]>
* Update datafusion/physical-expr/src/physical_expr.rs
Co-authored-by: Mehmet Ozan Kabak <[email protected]>
* Update datafusion/physical-expr/src/physical_expr.rs
Co-authored-by: Mehmet Ozan Kabak <[email protected]>
* Update datafusion/physical-expr/src/physical_expr.rs
Co-authored-by: Mehmet Ozan Kabak <[email protected]>
* Update datafusion/physical-expr/src/intervals/interval_aritmetic.rs
Co-authored-by: Mehmet Ozan Kabak <[email protected]>
* Comments added
* Minor improvements
* Clippy
* next_value func moves to interval module
* Reverts next_value(), adds test
* Final review
---------
Co-authored-by: Mustafa Akur <[email protected]>
Co-authored-by: metesynnada <[email protected]>
Co-authored-by: Mehmet Ozan Kabak <[email protected]>
---
datafusion/core/src/physical_plan/filter.rs | 464 ++++++++++++++++----
datafusion/physical-expr/src/expressions/binary.rs | 474 +--------------------
datafusion/physical-expr/src/expressions/column.rs | 91 +---
datafusion/physical-expr/src/expressions/like.rs | 22 +-
.../physical-expr/src/expressions/literal.rs | 35 +-
.../physical-expr/src/intervals/cp_solver.rs | 19 +-
.../src/intervals/interval_aritmetic.rs | 299 ++++++++++++-
datafusion/physical-expr/src/lib.rs | 4 +-
datafusion/physical-expr/src/physical_expr.rs | 328 ++++++++------
9 files changed, 913 insertions(+), 823 deletions(-)
diff --git a/datafusion/core/src/physical_plan/filter.rs
b/datafusion/core/src/physical_plan/filter.rs
index 6cb2490ee4..bb45149026 100644
--- a/datafusion/core/src/physical_plan/filter.rs
+++ b/datafusion/core/src/physical_plan/filter.rs
@@ -31,7 +31,6 @@ use super::{
use crate::physical_plan::{
metrics::{BaselineMetrics, ExecutionPlanMetricsSet, MetricsSet},
Column, DisplayFormatType, EquivalenceProperties, ExecutionPlan,
Partitioning,
- PhysicalExpr,
};
use arrow::compute::filter_record_batch;
@@ -42,8 +41,11 @@ use datafusion_common::{DataFusionError, Result};
use datafusion_execution::TaskContext;
use datafusion_expr::Operator;
use datafusion_physical_expr::expressions::BinaryExpr;
+use datafusion_physical_expr::intervals::is_operator_supported;
+use datafusion_physical_expr::utils::collect_columns;
use datafusion_physical_expr::{
- split_conjunction, AnalysisContext, OrderingEquivalenceProperties,
+ analyze, split_conjunction, AnalysisContext, ExprBoundaries,
+ OrderingEquivalenceProperties, PhysicalExpr,
};
use futures::stream::{Stream, StreamExt};
@@ -187,52 +189,88 @@ impl ExecutionPlan for FilterExec {
/// The output statistics of a filtering operation can be estimated if the
/// predicate's selectivity value can be determined for the incoming data.
fn statistics(&self) -> Statistics {
+ let predicate = self.predicate();
+
+ if let Some(binary) = predicate.as_any().downcast_ref::<BinaryExpr>() {
+ let columns = collect_columns(predicate);
+ if !is_operator_supported(binary.op()) || columns.is_empty() {
+ return Statistics::default();
+ }
+ }
+
let input_stats = self.input.statistics();
+ let input_column_stats = match input_stats.column_statistics {
+ Some(stats) => stats,
+ None => return Statistics::default(),
+ };
+
let starter_ctx =
- AnalysisContext::from_statistics(self.input.schema().as_ref(),
&input_stats);
-
- let analysis_ctx = self.predicate.analyze(starter_ctx);
-
- match analysis_ctx.boundaries {
- Some(boundaries) => {
- // Build back the column level statistics from the boundaries
inside the
- // analysis context. It is possible that these are going to be
different
- // than the input statistics, especially when a comparison is
made inside
- // the predicate expression (e.g. `col1 > 100`).
- let column_statistics = analysis_ctx
- .column_boundaries
- .iter()
- .map(|boundary| match boundary {
- Some(boundary) => ColumnStatistics {
- min_value: Some(boundary.min_value.clone()),
- max_value: Some(boundary.max_value.clone()),
- ..Default::default()
- },
- None => ColumnStatistics::default(),
- })
- .collect();
-
- Statistics {
- num_rows:
input_stats.num_rows.zip(boundaries.selectivity).map(
- |(num_rows, selectivity)| {
- (num_rows as f64 * selectivity).ceil() as usize
- },
- ),
- total_byte_size: input_stats
- .total_byte_size
- .zip(boundaries.selectivity)
- .map(|(num_rows, selectivity)| {
- (num_rows as f64 * selectivity).ceil() as usize
- }),
- column_statistics: Some(column_statistics),
- ..Default::default()
- }
- }
- None => Statistics::default(),
+ AnalysisContext::from_statistics(&self.input.schema(),
&input_column_stats);
+
+ let analysis_ctx = match analyze(predicate, starter_ctx) {
+ Ok(ctx) => ctx,
+ Err(_) => return Statistics::default(),
+ };
+
+ let selectivity = analysis_ctx.selectivity.unwrap_or(1.0);
+
+ let num_rows = input_stats
+ .num_rows
+ .map(|num| (num as f64 * selectivity).ceil() as usize);
+ let total_byte_size = input_stats
+ .total_byte_size
+ .map(|size| (size as f64 * selectivity).ceil() as usize);
+
+ let column_statistics = if let Some(analysis_boundaries) =
analysis_ctx.boundaries
+ {
+ collect_new_statistics(input_column_stats, selectivity,
analysis_boundaries)
+ } else {
+ input_column_stats
+ };
+
+ Statistics {
+ num_rows,
+ total_byte_size,
+ column_statistics: Some(column_statistics),
+ is_exact: Default::default(),
}
}
}
+/// Builds the output column statistics of a `FilterExec` from the column
+/// boundaries produced by predicate analysis.
+///
+/// For each column, the resulting `ColumnStatistics`:
+/// - keeps the `null_count` of the matching input column,
+/// - takes `distinct_count` from the analysis boundaries,
+/// - takes `min_value`/`max_value` from the column's interval after it is
+///   converted to closed bounds via `close_bounds()` (an open lower/upper
+///   bound is adjusted to the next/previous value of its data type), or
+///   leaves them as `None` when `selectivity` is zero.
+///
+/// NOTE(review): `input_column_stats` is indexed by the position of each
+/// boundary, so this assumes both vectors are derived from the same schema
+/// and have the same length; confirm at call sites.
+fn collect_new_statistics(
+ input_column_stats: Vec<ColumnStatistics>,
+ selectivity: f64,
+ analysis_boundaries: Vec<ExprBoundaries>,
+) -> Vec<ColumnStatistics> {
+ // With zero selectivity no rows are expected to survive the filter, so
+ // there is no meaningful min/max to report.
+ let nonempty_columns = selectivity > 0.0;
+ analysis_boundaries
+ .into_iter()
+ .enumerate()
+ .map(
+ |(
+ idx,
+ ExprBoundaries {
+ interval,
+ distinct_count,
+ ..
+ },
+ )| {
+ let closed_interval = interval.close_bounds();
+ ColumnStatistics {
+ null_count: input_column_stats[idx].null_count,
+ max_value:
nonempty_columns.then_some(closed_interval.upper.value),
+ min_value:
nonempty_columns.then_some(closed_interval.lower.value),
+ distinct_count,
+ }
+ },
+ )
+ .collect()
+}
+
/// The FilterExec streams wraps the input iterator and applies the predicate
expression to
/// determine which rows to include in its output batches
struct FilterExecStream {
@@ -485,40 +523,6 @@ mod tests {
let statistics = filter.statistics();
assert_eq!(statistics.num_rows, Some(25));
assert_eq!(statistics.total_byte_size, Some(25 * bytes_per_row));
-
- Ok(())
- }
-
- #[tokio::test]
- async fn test_filter_statistics_column_level_basic_expr() -> Result<()> {
- // Table:
- // a: min=1, max=100
- let schema = Schema::new(vec![Field::new("a", DataType::Int32,
false)]);
- let input = Arc::new(StatisticsExec::new(
- Statistics {
- num_rows: Some(100),
- column_statistics: Some(vec![ColumnStatistics {
- min_value: Some(ScalarValue::Int32(Some(1))),
- max_value: Some(ScalarValue::Int32(Some(100))),
- ..Default::default()
- }]),
- ..Default::default()
- },
- schema.clone(),
- ));
-
- // a <= 25
- let predicate: Arc<dyn PhysicalExpr> =
- binary(col("a", &schema)?, Operator::LtEq, lit(25i32), &schema)?;
-
- // WHERE a <= 25
- let filter: Arc<dyn ExecutionPlan> =
- Arc::new(FilterExec::try_new(predicate, input)?);
-
- let statistics = filter.statistics();
-
- // a must be in [1, 25] range now!
- assert_eq!(statistics.num_rows, Some(25));
assert_eq!(
statistics.column_statistics,
Some(vec![ColumnStatistics {
@@ -623,7 +627,6 @@ mod tests {
binary(col("a", &schema)?, Operator::GtEq, lit(10i32), &schema)?,
b_gt_5,
)?);
-
let statistics = filter.statistics();
// On a uniform distribution, only fifteen rows will satisfy the
// filter that 'a' proposed (a >= 10 AND a <= 25) (15/100) and only
@@ -641,7 +644,7 @@ mod tests {
..Default::default()
},
ColumnStatistics {
- min_value: Some(ScalarValue::Int32(Some(45))),
+ min_value: Some(ScalarValue::Int32(Some(46))),
max_value: Some(ScalarValue::Int32(Some(50))),
..Default::default()
}
@@ -679,4 +682,309 @@ mod tests {
Ok(())
}
+
+ /// Checks row-count, byte-size and per-column min/max estimates for a
+ /// conjunction that constrains several columns, including a column-to-column
+ /// comparison (`a > b`) and a floating-point column.
+ #[tokio::test]
+ async fn test_filter_statistics_multiple_columns() -> Result<()> {
+ // Table:
+ // a: min=1, max=100
+ // b: min=1, max=3
+ // c: min=1000.0 max=1100.0
+ let schema = Schema::new(vec![
+ Field::new("a", DataType::Int32, false),
+ Field::new("b", DataType::Int32, false),
+ Field::new("c", DataType::Float32, false),
+ ]);
+ let input = Arc::new(StatisticsExec::new(
+ Statistics {
+ num_rows: Some(1000),
+ total_byte_size: Some(4000),
+ column_statistics: Some(vec![
+ ColumnStatistics {
+ min_value: Some(ScalarValue::Int32(Some(1))),
+ max_value: Some(ScalarValue::Int32(Some(100))),
+ ..Default::default()
+ },
+ ColumnStatistics {
+ min_value: Some(ScalarValue::Int32(Some(1))),
+ max_value: Some(ScalarValue::Int32(Some(3))),
+ ..Default::default()
+ },
+ ColumnStatistics {
+ min_value: Some(ScalarValue::Float32(Some(1000.0))),
+ max_value: Some(ScalarValue::Float32(Some(1100.0))),
+ ..Default::default()
+ },
+ ]),
+ ..Default::default()
+ },
+ schema,
+ ));
+ // WHERE a<=53 AND (b=3 AND (c<=1075.0 AND a>b))
+ let predicate = Arc::new(BinaryExpr::new(
+ Arc::new(BinaryExpr::new(
+ Arc::new(Column::new("a", 0)),
+ Operator::LtEq,
+ Arc::new(Literal::new(ScalarValue::Int32(Some(53)))),
+ )),
+ Operator::And,
+ Arc::new(BinaryExpr::new(
+ Arc::new(BinaryExpr::new(
+ Arc::new(Column::new("b", 1)),
+ Operator::Eq,
+ Arc::new(Literal::new(ScalarValue::Int32(Some(3)))),
+ )),
+ Operator::And,
+ Arc::new(BinaryExpr::new(
+ Arc::new(BinaryExpr::new(
+ Arc::new(Column::new("c", 2)),
+ Operator::LtEq,
+ Arc::new(Literal::new(ScalarValue::Float32(Some(1075.0)))),
+ )),
+ Operator::And,
+ Arc::new(BinaryExpr::new(
+ Arc::new(Column::new("a", 0)),
+ Operator::Gt,
+ Arc::new(Column::new("b", 1)),
+ )),
+ )),
+ )),
+ ));
+ let filter: Arc<dyn ExecutionPlan> =
+ Arc::new(FilterExec::try_new(predicate, input)?);
+ let statistics = filter.statistics();
+ // 0.5 (from a) * 0.333333... (from b) * 0.798387... (from c) ≈ 0.1330...
+ // num_rows after ceil => 133.0... => 134
+ // total_byte_size after ceil => 532.0... => 533
+ assert_eq!(statistics.num_rows, Some(134));
+ assert_eq!(statistics.total_byte_size, Some(533));
+ let exp_col_stats = vec![
+ ColumnStatistics {
+ min_value: Some(ScalarValue::Int32(Some(4))),
+ max_value: Some(ScalarValue::Int32(Some(53))),
+ ..Default::default()
+ },
+ ColumnStatistics {
+ min_value: Some(ScalarValue::Int32(Some(3))),
+ max_value: Some(ScalarValue::Int32(Some(3))),
+ ..Default::default()
+ },
+ ColumnStatistics {
+ min_value: Some(ScalarValue::Float32(Some(1000.0))),
+ max_value: Some(ScalarValue::Float32(Some(1075.0))),
+ ..Default::default()
+ },
+ ];
+ let actual_col_stats = statistics.column_statistics.unwrap();
+ // `zip` stops at the shorter input, so pin the lengths first.
+ assert_eq!(actual_col_stats.len(), exp_col_stats.len());
+ // Use an eager `for` loop: iterator adapters such as `map` are lazy, and an
+ // unconsumed `map` would silently skip every assertion below.
+ for (expected, actual) in exp_col_stats.into_iter().zip(actual_col_stats) {
+ if actual
+ .min_value
+ .clone()
+ .unwrap()
+ .get_datatype()
+ .is_floating()
+ {
+ // Windows rounds arithmetic operation results differently for floating point numbers.
+ // Therefore, we check if the actual values are in an epsilon range.
+ let actual_min = actual.min_value.unwrap();
+ let actual_max = actual.max_value.unwrap();
+ let expected_min = expected.min_value.unwrap();
+ let expected_max = expected.max_value.unwrap();
+ let eps = ScalarValue::Float32(Some(1e-6));
+
+ // |actual - expected| < eps, checked by bounding the difference from both sides.
+ assert!(actual_min.sub(&expected_min).unwrap() < eps);
+ assert!(expected_min.sub(&actual_min).unwrap() < eps);
+
+ assert!(actual_max.sub(&expected_max).unwrap() < eps);
+ assert!(expected_max.sub(&actual_max).unwrap() < eps);
+ } else {
+ assert_eq!(actual, expected);
+ }
+ }
+
+ Ok(())
+ }
+
+ #[tokio::test]
+ async fn test_filter_statistics_full_selective() -> Result<()> {
+ // Table:
+ // a: min=1, max=100
+ // b: min=1, max=3
+ let schema = Schema::new(vec![
+ Field::new("a", DataType::Int32, false),
+ Field::new("b", DataType::Int32, false),
+ ]);
+ let input = Arc::new(StatisticsExec::new(
+ Statistics {
+ num_rows: Some(1000),
+ total_byte_size: Some(4000),
+ column_statistics: Some(vec![
+ ColumnStatistics {
+ min_value: Some(ScalarValue::Int32(Some(1))),
+ max_value: Some(ScalarValue::Int32(Some(100))),
+ ..Default::default()
+ },
+ ColumnStatistics {
+ min_value: Some(ScalarValue::Int32(Some(1))),
+ max_value: Some(ScalarValue::Int32(Some(3))),
+ ..Default::default()
+ },
+ ]),
+ ..Default::default()
+ },
+ schema,
+ ));
+ // WHERE a<200 AND 1<=b
+ let predicate = Arc::new(BinaryExpr::new(
+ Arc::new(BinaryExpr::new(
+ Arc::new(Column::new("a", 0)),
+ Operator::Lt,
+ Arc::new(Literal::new(ScalarValue::Int32(Some(200)))),
+ )),
+ Operator::And,
+ Arc::new(BinaryExpr::new(
+ Arc::new(Literal::new(ScalarValue::Int32(Some(1)))),
+ Operator::LtEq,
+ Arc::new(Column::new("b", 1)),
+ )),
+ ));
+ // Since filter predicate passes all entries, statistics after filter
shouldn't change.
+ let expected = input.statistics().column_statistics;
+ let filter: Arc<dyn ExecutionPlan> =
+ Arc::new(FilterExec::try_new(predicate, input)?);
+ let statistics = filter.statistics();
+
+ assert_eq!(statistics.num_rows, Some(1000));
+ assert_eq!(statistics.total_byte_size, Some(4000));
+ assert_eq!(statistics.column_statistics, expected);
+
+ Ok(())
+ }
+
+ #[tokio::test]
+ async fn test_filter_statistics_zero_selective() -> Result<()> {
+ // Table:
+ // a: min=1, max=100
+ // b: min=1, max=3
+ let schema = Schema::new(vec![
+ Field::new("a", DataType::Int32, false),
+ Field::new("b", DataType::Int32, false),
+ ]);
+ let input = Arc::new(StatisticsExec::new(
+ Statistics {
+ num_rows: Some(1000),
+ total_byte_size: Some(4000),
+ column_statistics: Some(vec![
+ ColumnStatistics {
+ min_value: Some(ScalarValue::Int32(Some(1))),
+ max_value: Some(ScalarValue::Int32(Some(100))),
+ ..Default::default()
+ },
+ ColumnStatistics {
+ min_value: Some(ScalarValue::Int32(Some(1))),
+ max_value: Some(ScalarValue::Int32(Some(3))),
+ ..Default::default()
+ },
+ ]),
+ ..Default::default()
+ },
+ schema,
+ ));
+ // WHERE a>200 AND 1<=b
+ let predicate = Arc::new(BinaryExpr::new(
+ Arc::new(BinaryExpr::new(
+ Arc::new(Column::new("a", 0)),
+ Operator::Gt,
+ Arc::new(Literal::new(ScalarValue::Int32(Some(200)))),
+ )),
+ Operator::And,
+ Arc::new(BinaryExpr::new(
+ Arc::new(Literal::new(ScalarValue::Int32(Some(1)))),
+ Operator::LtEq,
+ Arc::new(Column::new("b", 1)),
+ )),
+ ));
+ let filter: Arc<dyn ExecutionPlan> =
+ Arc::new(FilterExec::try_new(predicate, input)?);
+ let statistics = filter.statistics();
+
+ assert_eq!(statistics.num_rows, Some(0));
+ assert_eq!(statistics.total_byte_size, Some(0));
+ assert_eq!(
+ statistics.column_statistics,
+ Some(vec![
+ ColumnStatistics {
+ min_value: None,
+ max_value: None,
+ ..Default::default()
+ },
+ ColumnStatistics {
+ min_value: None,
+ max_value: None,
+ ..Default::default()
+ },
+ ])
+ );
+
+ Ok(())
+ }
+
+ #[tokio::test]
+ async fn test_filter_statistics_more_inputs() -> Result<()> {
+ let schema = Schema::new(vec![
+ Field::new("a", DataType::Int32, false),
+ Field::new("b", DataType::Int32, false),
+ ]);
+ let input = Arc::new(StatisticsExec::new(
+ Statistics {
+ num_rows: Some(1000),
+ total_byte_size: Some(4000),
+ column_statistics: Some(vec![
+ ColumnStatistics {
+ min_value: Some(ScalarValue::Int32(Some(1))),
+ max_value: Some(ScalarValue::Int32(Some(100))),
+ ..Default::default()
+ },
+ ColumnStatistics {
+ min_value: Some(ScalarValue::Int32(Some(1))),
+ max_value: Some(ScalarValue::Int32(Some(100))),
+ ..Default::default()
+ },
+ ]),
+ ..Default::default()
+ },
+ schema,
+ ));
+ // WHERE a<50
+ let predicate = Arc::new(BinaryExpr::new(
+ Arc::new(Column::new("a", 0)),
+ Operator::Lt,
+ Arc::new(Literal::new(ScalarValue::Int32(Some(50)))),
+ ));
+ let filter: Arc<dyn ExecutionPlan> =
+ Arc::new(FilterExec::try_new(predicate, input)?);
+ let statistics = filter.statistics();
+
+ assert_eq!(statistics.num_rows, Some(490));
+ assert_eq!(statistics.total_byte_size, Some(1960));
+ assert_eq!(
+ statistics.column_statistics,
+ Some(vec![
+ ColumnStatistics {
+ min_value: Some(ScalarValue::Int32(Some(1))),
+ max_value: Some(ScalarValue::Int32(Some(49))),
+ ..Default::default()
+ },
+ ColumnStatistics {
+ min_value: Some(ScalarValue::Int32(Some(1))),
+ max_value: Some(ScalarValue::Int32(Some(100))),
+ ..Default::default()
+ },
+ ])
+ );
+
+ Ok(())
+ }
}
diff --git a/datafusion/physical-expr/src/expressions/binary.rs
b/datafusion/physical-expr/src/expressions/binary.rs
index 9d2087a0c5..b453e8135d 100644
--- a/datafusion/physical-expr/src/expressions/binary.rs
+++ b/datafusion/physical-expr/src/expressions/binary.rs
@@ -49,13 +49,12 @@ use arrow::compute::kernels::comparison::{
eq_dyn_utf8_scalar, gt_dyn_utf8_scalar, gt_eq_dyn_utf8_scalar,
lt_dyn_utf8_scalar,
lt_eq_dyn_utf8_scalar, neq_dyn_utf8_scalar,
};
+use arrow::compute::kernels::concat_elements::concat_elements_utf8;
use arrow::compute::{cast, CastOptions};
use arrow::datatypes::*;
+use arrow::record_batch::RecordBatch;
use adapter::{eq_dyn, gt_dyn, gt_eq_dyn, lt_dyn, lt_eq_dyn, neq_dyn};
-use arrow::compute::kernels::concat_elements::concat_elements_utf8;
-
-use datafusion_expr::type_coercion::{is_decimal, is_timestamp,
is_utf8_or_large_utf8};
use kernels::{
bitwise_and_dyn, bitwise_and_dyn_scalar, bitwise_or_dyn,
bitwise_or_dyn_scalar,
bitwise_shift_left_dyn, bitwise_shift_left_dyn_scalar,
bitwise_shift_right_dyn,
@@ -73,32 +72,29 @@ use kernels_arrow::{
subtract_decimal_dyn_scalar, subtract_dyn_decimal, subtract_dyn_temporal,
};
-use arrow::datatypes::{DataType, Schema, TimeUnit};
-use arrow::record_batch::RecordBatch;
-
use self::kernels_arrow::{
add_dyn_temporal_left_scalar, add_dyn_temporal_right_scalar,
subtract_dyn_temporal_left_scalar, subtract_dyn_temporal_right_scalar,
};
-use super::column::Column;
use crate::array_expressions::{array_append, array_concat, array_prepend};
use crate::expressions::cast_column;
use crate::intervals::cp_solver::{propagate_arithmetic, propagate_comparison};
use crate::intervals::{apply_operator, Interval};
use crate::physical_expr::down_cast_any_ref;
-use crate::{analysis_expect, AnalysisContext, ExprBoundaries, PhysicalExpr};
-use datafusion_common::cast::as_boolean_array;
+use crate::PhysicalExpr;
+use datafusion_common::cast::as_boolean_array;
use datafusion_common::ScalarValue;
use datafusion_common::{DataFusionError, Result};
use datafusion_expr::type_coercion::binary::{
coercion_decimal_mathematics_type, get_result_type,
};
+use datafusion_expr::type_coercion::{is_decimal, is_timestamp,
is_utf8_or_large_utf8};
use datafusion_expr::{ColumnarValue, Operator};
/// Binary expression
-#[derive(Debug, Hash)]
+#[derive(Debug, Hash, Clone)]
pub struct BinaryExpr {
left: Arc<dyn PhysicalExpr>,
op: Operator,
@@ -761,55 +757,6 @@ impl PhysicalExpr for BinaryExpr {
)))
}
- /// Return the boundaries of this binary expression's result.
- fn analyze(&self, context: AnalysisContext) -> AnalysisContext {
- match &self.op {
- Operator::Eq
- | Operator::Gt
- | Operator::Lt
- | Operator::LtEq
- | Operator::GtEq => {
- // We currently only support comparison when we know at least
one of the sides are
- // a known value (a scalar). This includes predicates like a >
20 or 5 > a.
- let context = self.left.analyze(context);
- let left_boundaries =
- analysis_expect!(context, context.boundaries()).clone();
-
- let context = self.right.analyze(context);
- let right_boundaries =
- analysis_expect!(context, context.boundaries.clone());
-
- match (left_boundaries.reduce(), right_boundaries.reduce()) {
- (_, Some(right_value)) => {
- // We know the right side is a scalar, so we can use
the operator as is
- analyze_expr_scalar_comparison(
- context,
- &self.op,
- &self.left,
- right_value,
- )
- }
- (Some(left_value), _) => {
- // If not, we have to swap the operator and left/right
(since this means
- // left has to be a scalar).
- let swapped_op = analysis_expect!(context,
self.op.swap());
- analyze_expr_scalar_comparison(
- context,
- &swapped_op,
- &self.right,
- left_value,
- )
- }
- _ => {
- // Both sides are columns, so we give up.
- context.with_boundaries(None)
- }
- }
- }
- _ => context.with_boundaries(None),
- }
- }
-
fn evaluate_bounds(&self, children: &[&Interval]) -> Result<Interval> {
// Get children intervals:
let left_interval = children[0];
@@ -864,132 +811,6 @@ impl PartialEq<dyn Any> for BinaryExpr {
}
}
-// Analyze the comparison between an expression (on the left) and a scalar
value
-// (on the right). The new boundaries will indicate whether it is always true,
always
-// false, or unknown (with a probablistic selectivity value attached). This
operation
-// will also include the new upper/lower boundaries for the operand on the
left if
-// they can be determined.
-fn analyze_expr_scalar_comparison(
- context: AnalysisContext,
- op: &Operator,
- left: &Arc<dyn PhysicalExpr>,
- right: ScalarValue,
-) -> AnalysisContext {
- let left_bounds = analysis_expect!(context,
left.analyze(context.clone()).boundaries);
- let left_min = left_bounds.min_value;
- let left_max = left_bounds.max_value;
-
- // Direct selectivity is applicable when we can determine that this
comparison will
- // always be true or false (e.g. `x > 10` where the `x`'s min value is 11
or `a < 5`
- // where the `a`'s max value is 4).
- let (always_selects, never_selects) = match op {
- Operator::Lt => (right > left_max, right <= left_min),
- Operator::LtEq => (right >= left_max, right < left_min),
- Operator::Gt => (right < left_min, right >= left_max),
- Operator::GtEq => (right <= left_min, right > left_max),
- Operator::Eq => (
- // Since min/max can be artificial (e.g. the min or max value of a
column
- // might be under/over the real value), we can't assume if the
right equals
- // to any left.min / left.max values it is always going to be
selected. But
- // we can assume that if the range(left) doesn't overlap with
right, it is
- // never going to be selected.
- false,
- right < left_min || right > left_max,
- ),
- _ => unreachable!(),
- };
-
- // Both can not be true at the same time.
- assert!(!(always_selects && never_selects));
-
- let selectivity = match (always_selects, never_selects) {
- (true, _) => 1.0,
- (_, true) => 0.0,
- (false, false) => {
- // If there is a partial overlap, then we can estimate the
selectivity
- // by computing the ratio of the existing overlap to the total
range. Since we
- // currently don't have access to a value distribution histogram,
the part below
- // assumes a uniform distribution by default.
-
- // Our [min, max] is inclusive, so we need to add 1 to the
difference.
- let total_range = analysis_expect!(context,
left_max.distance(&left_min)) + 1;
- let overlap_between_boundaries = analysis_expect!(
- context,
- match op {
- Operator::Lt => right.distance(&left_min),
- Operator::Gt => left_max.distance(&right),
- Operator::LtEq => right.distance(&left_min).map(|dist|
dist + 1),
- Operator::GtEq => left_max.distance(&right).map(|dist|
dist + 1),
- Operator::Eq => Some(1),
- _ => None,
- }
- );
-
- overlap_between_boundaries as f64 / total_range as f64
- }
- };
-
- // The context represents all the knowledge we have gathered during the
- // analysis process, which we can now add more since the expression's upper
- // and lower boundaries might have changed.
- let context = match left.as_any().downcast_ref::<Column>() {
- Some(column_expr) => {
- let (left_min, left_max) = match op {
- // TODO: for lt/gt, we technically should shrink the
possibility space
- // by one since a < 5 means that 5 is not a possible value for
`a`. However,
- // it is currently tricky to do so (e.g. for floats, we can
get away with 4.999
- // so we need a smarter logic to find out what is the closest
value that is
- // different from the scalar_value).
- Operator::Lt | Operator::LtEq => {
- // We only want to update the upper bound when we know it
will help us (e.g.
- // it is actually smaller than what we have right now) and
it is a valid
- // value (e.g. [0, 100] < -100 would update the boundaries
to [0, -100] if
- // there weren't the selectivity check).
- if right < left_max && selectivity > 0.0 {
- (left_min, right)
- } else {
- (left_min, left_max)
- }
- }
- Operator::Gt | Operator::GtEq => {
- // Same as above, but this time we want to limit the lower
bound.
- if right > left_min && selectivity > 0.0 {
- (right, left_max)
- } else {
- (left_min, left_max)
- }
- }
- // For equality, we don't have the range problem so even if
the selectivity
- // is 0.0, we can still update the boundaries.
- Operator::Eq => (right.clone(), right),
- _ => unreachable!(),
- };
-
- let left_bounds =
- ExprBoundaries::new(left_min, left_max,
left_bounds.distinct_count);
- context.with_column_update(column_expr.index(), left_bounds)
- }
- None => context,
- };
-
- // The selectivity can't be be greater than 1.0.
- assert!(selectivity <= 1.0);
-
- let (pred_min, pred_max, pred_distinct) = match (always_selects,
never_selects) {
- (false, true) => (false, false, 1),
- (true, false) => (true, true, 1),
- _ => (false, true, 2),
- };
-
- let result_boundaries = Some(ExprBoundaries::new_with_selectivity(
- ScalarValue::Boolean(Some(pred_min)),
- ScalarValue::Boolean(Some(pred_max)),
- Some(pred_distinct),
- Some(selectivity),
- ));
- context.with_boundaries(result_boundaries)
-}
-
/// unwrap underlying (non dictionary) value, if any, to pass to a scalar
kernel
fn unwrap_dict_value(v: ScalarValue) -> ScalarValue {
if let ScalarValue::Dictionary(_key_type, v) = v {
@@ -1352,7 +1173,7 @@ mod tests {
ArrowNumericType, Decimal128Type, Field, Int32Type, SchemaRef,
};
use arrow_schema::ArrowError;
- use datafusion_common::{ColumnStatistics, Result, Statistics};
+ use datafusion_common::Result;
use datafusion_expr::type_coercion::binary::get_input_types;
/// Performs a binary operation, applying any type coercion necessary
@@ -4489,287 +4310,6 @@ mod tests {
Ok(())
}
- /// Return a pair of (schema, statistics) for a table with a single column
(called "a") with
- /// the same type as the `min_value`/`max_value`.
- fn get_test_table_stats(
- min_value: ScalarValue,
- max_value: ScalarValue,
- ) -> (Schema, Statistics) {
- assert_eq!(min_value.get_datatype(), max_value.get_datatype());
- let schema = Schema::new(vec![Field::new("a",
min_value.get_datatype(), false)]);
- let columns = vec![ColumnStatistics {
- min_value: Some(min_value),
- max_value: Some(max_value),
- null_count: None,
- distinct_count: None,
- }];
- let statistics = Statistics {
- column_statistics: Some(columns),
- ..Default::default()
- };
- (schema, statistics)
- }
-
- #[test]
- fn test_analyze_expr_scalar_comparison() -> Result<()> {
- // A table where the column 'a' has a min of 1, a max of 100.
- let (schema, statistics) =
- get_test_table_stats(ScalarValue::from(1i64),
ScalarValue::from(100i64));
-
- let cases = [
- // (operator, rhs), (expected selectivity, expected min, expected
max)
- //
-------------------------------------------------------------------
- //
- // Table:
- // - a (min = 1, max = 100, distinct_count = null)
- //
- // Equality (a = $):
- //
- ((Operator::Eq, 1), (1.0 / 100.0, 1, 1)),
- ((Operator::Eq, 5), (1.0 / 100.0, 5, 5)),
- ((Operator::Eq, 99), (1.0 / 100.0, 99, 99)),
- ((Operator::Eq, 100), (1.0 / 100.0, 100, 100)),
- // For never matches like the following, we still produce the
correct
- // min/max values since if this condition holds by an off chance,
then
- // the result of expression will effectively become the = $limit.
- ((Operator::Eq, 0), (0.0, 0, 0)),
- ((Operator::Eq, -101), (0.0, -101, -101)),
- ((Operator::Eq, 101), (0.0, 101, 101)),
- //
- // Less than (a < $):
- //
- // Note: upper bounds for less than is currently overstated (by
the closest value).
- // see the comment in `compare_left_boundaries` for the reason
- ((Operator::Lt, 5), (4.0 / 100.0, 1, 5)),
- ((Operator::Lt, 99), (98.0 / 100.0, 1, 99)),
- ((Operator::Lt, 101), (100.0 / 100.0, 1, 100)),
- // Unlike equality, we now have an obligation to provide a range
of values here
- // so if "col < -100" expr is executed, we don't want to say col
can take [0, -100].
- ((Operator::Lt, 0), (0.0, 1, 100)),
- ((Operator::Lt, 1), (0.0, 1, 100)),
- ((Operator::Lt, -100), (0.0, 1, 100)),
- ((Operator::Lt, -200), (0.0, 1, 100)),
- // We also don't want to expand the range unnecessarily even if
the predicate is
- // successful.
- ((Operator::Lt, 200), (1.0, 1, 100)),
- //
- // Less than or equal (a <= $):
- //
- ((Operator::LtEq, -100), (0.0, 1, 100)),
- ((Operator::LtEq, 0), (0.0, 1, 100)),
- ((Operator::LtEq, 1), (1.0 / 100.0, 1, 1)),
- ((Operator::LtEq, 5), (5.0 / 100.0, 1, 5)),
- ((Operator::LtEq, 99), (99.0 / 100.0, 1, 99)),
- ((Operator::LtEq, 100), (100.0 / 100.0, 1, 100)),
- ((Operator::LtEq, 101), (1.0, 1, 100)),
- ((Operator::LtEq, 200), (1.0, 1, 100)),
- //
- // Greater than (a > $):
- //
- ((Operator::Gt, -100), (1.0, 1, 100)),
- ((Operator::Gt, 0), (1.0, 1, 100)),
- ((Operator::Gt, 1), (99.0 / 100.0, 1, 100)),
- ((Operator::Gt, 5), (95.0 / 100.0, 5, 100)),
- ((Operator::Gt, 99), (1.0 / 100.0, 99, 100)),
- ((Operator::Gt, 100), (0.0, 1, 100)),
- ((Operator::Gt, 101), (0.0, 1, 100)),
- ((Operator::Gt, 200), (0.0, 1, 100)),
- //
- // Greater than or equal (a >= $):
- //
- ((Operator::GtEq, -100), (1.0, 1, 100)),
- ((Operator::GtEq, 0), (1.0, 1, 100)),
- ((Operator::GtEq, 1), (1.0, 1, 100)),
- ((Operator::GtEq, 5), (96.0 / 100.0, 5, 100)),
- ((Operator::GtEq, 99), (2.0 / 100.0, 99, 100)),
- ((Operator::GtEq, 100), (1.0 / 100.0, 100, 100)),
- ((Operator::GtEq, 101), (0.0, 1, 100)),
- ((Operator::GtEq, 200), (0.0, 1, 100)),
- ];
-
- for ((operator, rhs), (exp_selectivity, exp_min, exp_max)) in cases {
- let context = AnalysisContext::from_statistics(&schema,
&statistics);
- let left = col("a", &schema).unwrap();
- let right = ScalarValue::Int64(Some(rhs));
- let analysis_ctx =
- analyze_expr_scalar_comparison(context, &operator, &left,
right);
- let boundaries = analysis_ctx
- .boundaries
- .as_ref()
- .expect("Analysis must complete for this test!");
-
- assert_eq!(
- boundaries
- .selectivity
- .expect("compare_left_boundaries must produce a
selectivity value"),
- exp_selectivity
- );
-
- if exp_selectivity == 1.0 {
- // When the expected selectivity is 1.0, the resulting
expression
- // should always be true.
- assert_eq!(boundaries.reduce(),
Some(ScalarValue::Boolean(Some(true))));
- } else if exp_selectivity == 0.0 {
- // When the expected selectivity is 0.0, the resulting
expression
- // should always be false.
- assert_eq!(boundaries.reduce(),
Some(ScalarValue::Boolean(Some(false))));
- } else {
- // Otherwise, it should be [false, true] (since we don't know
anything for sure)
- assert_eq!(boundaries.min_value,
ScalarValue::Boolean(Some(false)));
- assert_eq!(boundaries.max_value,
ScalarValue::Boolean(Some(true)));
- }
-
- // For getting the updated boundaries, we can simply analyze the
LHS
- // with the existing context.
- let left_boundaries = left
- .analyze(analysis_ctx)
- .boundaries
- .expect("this case should not return None");
- assert_eq!(left_boundaries.min_value,
ScalarValue::Int64(Some(exp_min)));
- assert_eq!(left_boundaries.max_value,
ScalarValue::Int64(Some(exp_max)));
- }
- Ok(())
- }
-
- #[test]
- fn test_comparison_result_estimate_different_type() -> Result<()> {
- // A table where the column 'a' has a min of 1.3, a max of 50.7.
- let (schema, statistics) =
- get_test_table_stats(ScalarValue::from(1.3),
ScalarValue::from(50.7));
- let distance = 50.0; // rounded distance is (max - min) + 1
-
- // Since the generic version already covers all the paths, we can just
- // test a small subset of the cases.
- let cases = [
- // (operator, rhs), (expected selectivity, expected min, expected
max)
- //
-------------------------------------------------------------------
- //
- // Table:
- // - a (min = 1.3, max = 50.7, distinct_count = 25)
- //
- // Never selects (out of range)
- ((Operator::Eq, 1.1), (0.0, 1.1, 1.1)),
- ((Operator::Eq, 50.75), (0.0, 50.75, 50.75)),
- ((Operator::Lt, 1.3), (0.0, 1.3, 50.7)),
- ((Operator::LtEq, 1.29), (0.0, 1.3, 50.7)),
- ((Operator::Gt, 50.7), (0.0, 1.3, 50.7)),
- ((Operator::GtEq, 50.75), (0.0, 1.3, 50.7)),
- // Always selects
- ((Operator::Lt, 50.75), (1.0, 1.3, 50.7)),
- ((Operator::LtEq, 50.75), (1.0, 1.3, 50.7)),
- ((Operator::Gt, 1.29), (1.0, 1.3, 50.7)),
- ((Operator::GtEq, 1.3), (1.0, 1.3, 50.7)),
- // Partial selection (the x in 'x/distance' is basically the
rounded version of
- // the bound distance, as per the implementation).
- ((Operator::Eq, 27.8), (1.0 / distance, 27.8, 27.8)),
- ((Operator::Lt, 5.2), (4.0 / distance, 1.3, 5.2)), // On a uniform
distribution, this is {2.6, 3.9}
- ((Operator::LtEq, 1.3), (1.0 / distance, 1.3, 1.3)),
- ((Operator::Gt, 45.5), (5.0 / distance, 45.5, 50.7)), // On a
uniform distribution, this is {46.8, 48.1, 49.4}
- ((Operator::GtEq, 50.7), (1.0 / distance, 50.7, 50.7)),
- ];
-
- for ((operator, rhs), (exp_selectivity, exp_min, exp_max)) in cases {
- let context = AnalysisContext::from_statistics(&schema,
&statistics);
- let left = col("a", &schema).unwrap();
- let right = ScalarValue::from(rhs);
- let analysis_ctx =
- analyze_expr_scalar_comparison(context, &operator, &left,
right);
- let boundaries = analysis_ctx
- .clone()
- .boundaries
- .expect("Analysis must complete for this test!");
-
- assert_eq!(
- boundaries
- .selectivity
- .expect("compare_left_boundaries must produce a
selectivity value"),
- exp_selectivity
- );
-
- if exp_selectivity == 1.0 {
- // When the expected selectivity is 1.0, the resulting
expression
- // should always be true.
- assert_eq!(boundaries.reduce(), Some(ScalarValue::from(true)));
- } else if exp_selectivity == 0.0 {
- // When the expected selectivity is 0.0, the resulting
expression
- // should always be false.
- assert_eq!(boundaries.reduce(),
Some(ScalarValue::from(false)));
- } else {
- // Otherwise, it should be [false, true] (since we don't know
anything for sure)
- assert_eq!(boundaries.min_value, ScalarValue::from(false));
- assert_eq!(boundaries.max_value, ScalarValue::from(true));
- }
-
- let left_boundaries = left
- .analyze(analysis_ctx)
- .boundaries
- .expect("this case should not return None");
- assert_eq!(
- left_boundaries.min_value,
- ScalarValue::Float64(Some(exp_min))
- );
- assert_eq!(
- left_boundaries.max_value,
- ScalarValue::Float64(Some(exp_max))
- );
- }
- Ok(())
- }
-
- #[test]
- fn test_binary_expression_boundaries() -> Result<()> {
- // A table where the column 'a' has a min of 1, a max of 100.
- let (schema, statistics) =
- get_test_table_stats(ScalarValue::from(1), ScalarValue::from(100));
-
- // expression: "a >= 25"
- let a = col("a", &schema).unwrap();
- let gt = binary(
- a.clone(),
- Operator::GtEq,
- lit(ScalarValue::from(25)),
- &schema,
- )?;
-
- let context = AnalysisContext::from_statistics(&schema, &statistics);
- let predicate_boundaries = gt
- .analyze(context)
- .boundaries
- .expect("boundaries should not be None");
- assert_eq!(predicate_boundaries.selectivity, Some(0.76));
-
- Ok(())
- }
-
- #[test]
- fn test_binary_expression_boundaries_rhs() -> Result<()> {
- // This test is about the column rewriting feature in the boundary
provider
- // (e.g. if the lhs is a literal and rhs is the column, then we swap
them when
- // doing the computation).
-
- // A table where the column 'a' has a min of 1, a max of 100.
- let (schema, statistics) =
- get_test_table_stats(ScalarValue::from(1), ScalarValue::from(100));
-
- // expression: "50 >= a"
- let a = col("a", &schema).unwrap();
- let gt = binary(
- lit(ScalarValue::from(50)),
- Operator::GtEq,
- a.clone(),
- &schema,
- )?;
-
- let context = AnalysisContext::from_statistics(&schema, &statistics);
- let predicate_boundaries = gt
- .analyze(context)
- .boundaries
- .expect("boundaries should not be None");
- assert_eq!(predicate_boundaries.selectivity, Some(0.5));
-
- Ok(())
- }
-
#[test]
fn test_display_and_or_combo() {
let expr = BinaryExpr::new(
diff --git a/datafusion/physical-expr/src/expressions/column.rs
b/datafusion/physical-expr/src/expressions/column.rs
index 9eca9bf713..3b0d77b304 100644
--- a/datafusion/physical-expr/src/expressions/column.rs
+++ b/datafusion/physical-expr/src/expressions/column.rs
@@ -21,13 +21,13 @@ use std::any::Any;
use std::hash::{Hash, Hasher};
use std::sync::Arc;
+use crate::physical_expr::down_cast_any_ref;
+use crate::PhysicalExpr;
+
use arrow::{
datatypes::{DataType, Schema},
record_batch::RecordBatch,
};
-
-use crate::physical_expr::down_cast_any_ref;
-use crate::{AnalysisContext, PhysicalExpr};
use datafusion_common::{DataFusionError, Result};
use datafusion_expr::ColumnarValue;
@@ -104,13 +104,6 @@ impl PhysicalExpr for Column {
Ok(self)
}
- /// Return the boundaries of this column, if known.
- fn analyze(&self, context: AnalysisContext) -> AnalysisContext {
- assert!(self.index < context.column_boundaries.len());
- let col_bounds = context.column_boundaries[self.index].clone();
- context.with_boundaries(col_bounds)
- }
-
fn dyn_hash(&self, state: &mut dyn Hasher) {
let mut s = state;
self.hash(&mut s);
@@ -221,11 +214,13 @@ pub fn col(name: &str, schema: &Schema) -> Result<Arc<dyn
PhysicalExpr>> {
#[cfg(test)]
mod test {
use crate::expressions::Column;
- use crate::{AnalysisContext, ExprBoundaries, PhysicalExpr};
+ use crate::PhysicalExpr;
+
use arrow::array::StringArray;
use arrow::datatypes::{DataType, Field, Schema};
use arrow::record_batch::RecordBatch;
- use datafusion_common::{ColumnStatistics, Result, ScalarValue, Statistics};
+ use datafusion_common::Result;
+
use std::sync::Arc;
#[test]
@@ -263,76 +258,4 @@ mod test {
&format!("{error}"));
Ok(())
}
-
- /// Returns a pair of (schema, statistics) for a table of:
- /// - a => Stats(range=[1, 100], distinct=15)
- /// - b => unknown
- /// - c => Stats(range=[1, 100], distinct=unknown)
- fn get_test_table_stats() -> (Schema, Statistics) {
- let schema = Schema::new(vec![
- Field::new("a", DataType::Int32, true),
- Field::new("b", DataType::Int32, true),
- Field::new("c", DataType::Int32, true),
- ]);
-
- let columns = vec![
- ColumnStatistics {
- min_value: Some(ScalarValue::Int32(Some(1))),
- max_value: Some(ScalarValue::Int32(Some(100))),
- distinct_count: Some(15),
- ..Default::default()
- },
- ColumnStatistics::default(),
- ColumnStatistics {
- min_value: Some(ScalarValue::Int32(Some(1))),
- max_value: Some(ScalarValue::Int32(Some(75))),
- distinct_count: None,
- ..Default::default()
- },
- ];
-
- let statistics = Statistics {
- column_statistics: Some(columns),
- ..Default::default()
- };
-
- (schema, statistics)
- }
-
- #[test]
- fn stats_bounds_analysis() -> Result<()> {
- let (schema, statistics) = get_test_table_stats();
- let context = AnalysisContext::from_statistics(&schema, &statistics);
-
- let cases = [
- // (name, index, expected boundaries)
- (
- "a",
- 0,
- Some(ExprBoundaries::new(
- ScalarValue::Int32(Some(1)),
- ScalarValue::Int32(Some(100)),
- Some(15),
- )),
- ),
- ("b", 1, None),
- (
- "c",
- 2,
- Some(ExprBoundaries::new(
- ScalarValue::Int32(Some(1)),
- ScalarValue::Int32(Some(75)),
- None,
- )),
- ),
- ];
-
- for (name, index, expected) in cases {
- let col = Column::new(name, index);
- let test_ctx = col.analyze(context.clone());
- assert_eq!(test_ctx.boundaries, expected);
- }
-
- Ok(())
- }
}
diff --git a/datafusion/physical-expr/src/expressions/like.rs
b/datafusion/physical-expr/src/expressions/like.rs
index f549613acb..9523a4efd8 100644
--- a/datafusion/physical-expr/src/expressions/like.rs
+++ b/datafusion/physical-expr/src/expressions/like.rs
@@ -18,15 +18,7 @@
use std::hash::{Hash, Hasher};
use std::{any::Any, sync::Arc};
-use arrow::{
- array::{new_null_array, Array, ArrayRef, LargeStringArray, StringArray},
- record_batch::RecordBatch,
-};
-use arrow_schema::{DataType, Schema};
-use datafusion_common::{DataFusionError, Result, ScalarValue};
-use datafusion_expr::ColumnarValue;
-
-use crate::{physical_expr::down_cast_any_ref, AnalysisContext, PhysicalExpr};
+use crate::{physical_expr::down_cast_any_ref, PhysicalExpr};
use arrow::compute::kernels::comparison::{
ilike_utf8, like_utf8, nilike_utf8, nlike_utf8,
@@ -34,6 +26,13 @@ use arrow::compute::kernels::comparison::{
use arrow::compute::kernels::comparison::{
ilike_utf8_scalar, like_utf8_scalar, nilike_utf8_scalar, nlike_utf8_scalar,
};
+use arrow::{
+ array::{new_null_array, Array, ArrayRef, LargeStringArray, StringArray},
+ record_batch::RecordBatch,
+};
+use arrow_schema::{DataType, Schema};
+use datafusion_common::{DataFusionError, Result, ScalarValue};
+use datafusion_expr::ColumnarValue;
// Like expression
#[derive(Debug, Hash)]
@@ -183,11 +182,6 @@ impl PhysicalExpr for LikeExpr {
)))
}
- /// Return the boundaries of this binary expression's result.
- fn analyze(&self, context: AnalysisContext) -> AnalysisContext {
- context.with_boundaries(None)
- }
-
fn dyn_hash(&self, state: &mut dyn Hasher) {
let mut s = state;
self.hash(&mut s);
diff --git a/datafusion/physical-expr/src/expressions/literal.rs
b/datafusion/physical-expr/src/expressions/literal.rs
index 8cb2bd5b95..8e86716123 100644
--- a/datafusion/physical-expr/src/expressions/literal.rs
+++ b/datafusion/physical-expr/src/expressions/literal.rs
@@ -21,15 +21,14 @@ use std::any::Any;
use std::hash::{Hash, Hasher};
use std::sync::Arc;
+use crate::physical_expr::down_cast_any_ref;
+use crate::PhysicalExpr;
+
use arrow::{
datatypes::{DataType, Schema},
record_batch::RecordBatch,
};
-
-use crate::physical_expr::down_cast_any_ref;
-use crate::{AnalysisContext, ExprBoundaries, PhysicalExpr};
-use datafusion_common::Result;
-use datafusion_common::ScalarValue;
+use datafusion_common::{Result, ScalarValue};
use datafusion_expr::{ColumnarValue, Expr};
/// Represents a literal value
@@ -85,16 +84,6 @@ impl PhysicalExpr for Literal {
Ok(self)
}
- /// Return the boundaries of this literal expression (which is the same as
- /// the value it represents).
- fn analyze(&self, context: AnalysisContext) -> AnalysisContext {
- context.with_boundaries(Some(ExprBoundaries::new(
- self.value.clone(),
- self.value.clone(),
- Some(1),
- )))
- }
-
fn dyn_hash(&self, state: &mut dyn Hasher) {
let mut s = state;
self.hash(&mut s);
@@ -148,20 +137,4 @@ mod tests {
Ok(())
}
-
- #[test]
- fn literal_bounds_analysis() -> Result<()> {
- let schema = Schema::empty();
- let context = AnalysisContext::new(&schema, vec![]);
-
- let literal_expr = lit(42i32);
- let result_ctx = literal_expr.analyze(context);
- let boundaries = result_ctx.boundaries.unwrap();
- assert_eq!(boundaries.min_value, ScalarValue::Int32(Some(42)));
- assert_eq!(boundaries.max_value, ScalarValue::Int32(Some(42)));
- assert_eq!(boundaries.distinct_count, Some(1));
- assert_eq!(boundaries.selectivity, None);
-
- Ok(())
- }
}
diff --git a/datafusion/physical-expr/src/intervals/cp_solver.rs
b/datafusion/physical-expr/src/intervals/cp_solver.rs
index a1698e6651..edf1507c70 100644
--- a/datafusion/physical-expr/src/intervals/cp_solver.rs
+++ b/datafusion/physical-expr/src/intervals/cp_solver.rs
@@ -119,7 +119,7 @@ use super::IntervalBound;
/// This object implements a directed acyclic expression graph (DAEG) that
/// is used to compute ranges for expressions through interval arithmetic.
-#[derive(Clone)]
+#[derive(Clone, Debug)]
pub struct ExprIntervalGraph {
graph: StableGraph<ExprIntervalGraphNode, usize>,
root: NodeIndex,
@@ -251,10 +251,10 @@ pub fn propagate_arithmetic(
}
/// This function provides a target parent interval for comparison operators.
-/// If we have expression > 0, expression must have the range [0, ∞].
-/// If we have expression < 0, expression must have the range [-∞, 0].
-/// Currently, we only support strict inequalities since open/closed intervals
-/// are not implemented yet.
+/// If we have expression > 0, expression must have the range (0, ∞).
+/// If we have expression >= 0, expression must have the range [0, ∞).
+/// If we have expression < 0, expression must have the range (-∞, 0).
+/// If we have expression <= 0, expression must have the range (-∞, 0].
+/// If we have expression = 0, expression must have the range [0, 0].
fn comparison_operator_target(
left_datatype: &DataType,
op: &Operator,
@@ -268,6 +268,10 @@ fn comparison_operator_target(
Operator::Gt => Interval::new(IntervalBound::new(zero, true),
unbounded),
Operator::LtEq => Interval::new(unbounded, IntervalBound::new(zero,
false)),
Operator::Lt => Interval::new(unbounded, IntervalBound::new(zero,
true)),
+ Operator::Eq => Interval::new(
+ IntervalBound::new(zero.clone(), false),
+ IntervalBound::new(zero, false),
+ ),
_ => unreachable!(),
})
}
@@ -531,6 +535,11 @@ impl ExprIntervalGraph {
Ok(PropagationResult::CannotPropagate)
}
}
+
+    /// Returns a clone of the interval currently stored on the graph node
+    /// identified by `index`.
+    ///
+    /// `index` is wrapped with `NodeIndex::new` and used to index the graph
+    /// directly (an unknown index is presumably a panic in petgraph — verify).
+    pub fn get_interval(&self, index: usize) -> Interval {
+        let node_index = NodeIndex::new(index);
+        let node = &self.graph[node_index];
+        node.interval.clone()
+    }
}
/// Indicates whether interval arithmetic is supported for the given
expression.
diff --git a/datafusion/physical-expr/src/intervals/interval_aritmetic.rs
b/datafusion/physical-expr/src/intervals/interval_aritmetic.rs
index 3e2b4697a1..f72006ab5c 100644
--- a/datafusion/physical-expr/src/intervals/interval_aritmetic.rs
+++ b/datafusion/physical-expr/src/intervals/interval_aritmetic.rs
@@ -20,16 +20,18 @@
use std::borrow::Borrow;
use std::fmt;
use std::fmt::{Display, Formatter};
+use std::ops::{AddAssign, SubAssign};
+
+use crate::aggregate::min_max::{max, min};
+use crate::intervals::rounding::alter_fp_rounding_mode;
use arrow::compute::{cast_with_options, CastOptions};
use arrow::datatypes::DataType;
+use arrow_array::ArrowNativeTypeOp;
use datafusion_common::{DataFusionError, Result, ScalarValue};
use datafusion_expr::type_coercion::binary::get_result_type;
use datafusion_expr::Operator;
-use crate::aggregate::min_max::{max, min};
-use crate::intervals::rounding::alter_fp_rounding_mode;
-
/// This type represents a single endpoint of an [`Interval`]. An endpoint can
/// be open or closed, denoting whether the interval includes or excludes the
/// endpoint itself.
@@ -451,6 +453,144 @@ impl Interval {
lower: IntervalBound::new(ScalarValue::Boolean(Some(true)), false),
upper: IntervalBound::new(ScalarValue::Boolean(Some(true)), false),
};
+
+ /// Returns the cardinality of this interval, which is the number of all
+ /// distinct points inside it.
+ ///
+ /// Only integer and floating-point intervals whose bounds are both non-null
+ /// are supported; anything else returns an error. For floats the count is the
+ /// distance between the IEEE 754 bit patterns of the two bounds, adjusted for
+ /// open endpoints by `calculate_cardinality_based_on_bounds`.
+ pub fn cardinality(&self) -> Result<u64> {
+ match self.get_datatype() {
+ Ok(data_type) if data_type.is_integer() => {
+ // `distance` yielding `None` (e.g. presumably for null/unbounded
+ // endpoints — confirm against `ScalarValue::distance`) is an error.
+ if let Some(diff) =
self.upper.value.distance(&self.lower.value) {
+ Ok(calculate_cardinality_based_on_bounds(
+ self.lower.open,
+ self.upper.open,
+ diff as u64,
+ ))
+ } else {
+ Err(DataFusionError::Execution(format!(
+ "Cardinality cannot be calculated for {:?}",
+ self
+ )))
+ }
+ }
+ // For non-negative floats, ordering by the raw IEEE 754 bit pattern
+ // coincides with the natural ordering (for negative floats it is
+ // reversed, hence the swap below). Therefore, we can treat the bit
+ // patterns as "indices" and subtract them. For details, see:
+ // https://stackoverflow.com/questions/8875064/how-many-distinct-floating-point-numbers-in-a-specific-range
+ Ok(data_type) if data_type.is_floating() => {
+ // If the minimum value is a negative number, we need to
+ // switch sides to ensure an unsigned result.
+ //
+ // NOTE(review): when the lower bound is negative and the upper bound
+ // is positive, the subtraction below includes the sign bit, so the
+ // result is not the number of floats in between
+ // (`test_cardinality_of_intervals` pins 2^63 for [-0.0625, 0.0625)).
+ // Also, a lower bound of -0.0 is not `< 0`, so no swap happens and
+ // `sub_checked` fails on the underflow. Confirm before relying on
+ // either case.
+ let (min, max) = if self.lower.value
+ < ScalarValue::new_zero(&self.lower.value.get_datatype())?
+ {
+ (self.upper.value.clone(), self.lower.value.clone())
+ } else {
+ (self.lower.value.clone(), self.upper.value.clone())
+ };
+
+ match (min, max) {
+ (
+ ScalarValue::Float32(Some(lower)),
+ ScalarValue::Float32(Some(upper)),
+ ) => Ok(calculate_cardinality_based_on_bounds(
+ self.lower.open,
+ self.upper.open,
+ (upper.to_bits().sub_checked(lower.to_bits()))? as u64,
+ )),
+ (
+ ScalarValue::Float64(Some(lower)),
+ ScalarValue::Float64(Some(upper)),
+ ) => Ok(calculate_cardinality_based_on_bounds(
+ self.lower.open,
+ self.upper.open,
+ upper.to_bits().sub_checked(lower.to_bits())?,
+ )),
+ // Null bounds (or a mismatched pair) end up here.
+ _ => Err(DataFusionError::Execution(format!(
+ "Cardinality cannot be calculated for the datatype
{:?}",
+ data_type
+ ))),
+ }
+ }
+ // If the cardinality cannot be calculated anyway, give an error.
+ _ => Err(DataFusionError::Execution(format!(
+ "Cardinality cannot be calculated for {:?}",
+ self
+ ))),
+ }
+ }
+
+    /// "Closes" this interval: returns the narrowest closed interval that
+    /// contains the original one.
+    ///
+    /// An open lower endpoint is replaced by the value that follows it and an
+    /// open upper endpoint by the value that precedes it, both as produced by
+    /// `next_value`. Closed endpoints keep their value.
+    pub fn close_bounds(mut self) -> Interval {
+        let (lower_was_open, upper_was_open) = (self.lower.open, self.upper.open);
+        if lower_was_open {
+            self.lower.value = next_value::<true>(self.lower.value);
+        }
+        if upper_was_open {
+            self.upper.value = next_value::<false>(self.upper.value);
+        }
+        // Both endpoints are closed from here on, whether or not they were open.
+        self.lower.open = false;
+        self.upper.open = false;
+        self
+    }
+}
+
+/// Integer types that can produce the value one; `increment_decrement` uses
+/// it to step a value by a single unit.
+trait OneTrait: Sized + std::ops::Add + std::ops::Sub {
+    /// The value `1` of the implementing type.
+    fn one() -> Self;
+}
+
+/// Implements `OneTrait` for each listed primitive integer type.
+macro_rules! impl_OneTrait {
+    ($($m:ty),*) => {
+        $(
+            impl OneTrait for $m {
+                fn one() -> Self {
+                    1
+                }
+            }
+        )*
+    };
+}
+impl_OneTrait! {u8, u16, u32, u64, i8, i16, i32, i64}
+
+/// Adds one to `val` when `INC` is `true` and subtracts one otherwise.
+///
+/// No overflow check happens here: stepping past the type's range follows the
+/// ordinary `+=`/`-=` semantics of the primitive integers this is used with
+/// (a panic in debug builds), so callers must stay away from the limits.
+fn increment_decrement<const INC: bool, T: OneTrait + SubAssign + AddAssign>(
+    mut val: T,
+) -> T {
+    let step = T::one();
+    if INC {
+        val += step
+    } else {
+        val -= step
+    }
+    val
+}
+
+/// Returns the value adjacent to `value` in the direction selected by `INC`:
+/// the next representable value when `INC` is `true`, the previous one otherwise.
+///
+/// * Integers step by one.
+/// * Floats step to the neighbouring representable number. This works on the
+///   IEEE 754 bit pattern but takes the sign into account (for negative numbers
+///   a larger bit pattern means a *smaller* value) and steps across zero
+///   explicitly, so both signs and both zeros are handled.
+/// * The extreme values of each type (`MIN` when stepping down, `MAX` when
+///   stepping up) are returned unchanged instead of overflowing, and so are NaN,
+///   infinities, nulls (infinite bounds) and unsupported data types.
+fn next_value<const INC: bool>(value: ScalarValue) -> ScalarValue {
+    use ScalarValue::*;
+
+    // Steps an integer by one unless it already sits at the limit in the
+    // requested direction (stepping there would overflow).
+    macro_rules! step_integer {
+        ($variant:ident, $ty:ty, $val:expr) => {{
+            let val: $ty = $val;
+            let at_limit = if INC { val == <$ty>::MAX } else { val == <$ty>::MIN };
+            if at_limit {
+                $variant(Some(val))
+            } else {
+                $variant(Some(increment_decrement::<INC, $ty>(val)))
+            }
+        }};
+    }
+
+    // Steps a float to its neighbouring representable value.
+    macro_rules! step_float {
+        ($variant:ident, $ty:ty, $val:expr) => {{
+            let val: $ty = $val;
+            let at_limit = if INC { val == <$ty>::MAX } else { val == <$ty>::MIN };
+            let stepped = if val.is_nan() || val.is_infinite() || at_limit {
+                // No finite neighbour in the requested direction: keep the value.
+                val
+            } else if val == 0.0 {
+                // +0.0 and -0.0 both step to the smallest subnormal of the
+                // target sign (decrementing the bits of +0.0 would underflow).
+                let smallest = <$ty>::from_bits(1);
+                if INC {
+                    smallest
+                } else {
+                    -smallest
+                }
+            } else {
+                // The bit pattern grows with the magnitude, so it must grow when
+                // moving away from zero and shrink when moving towards it.
+                let bits = val.to_bits();
+                let away_from_zero = INC == (val > 0.0);
+                <$ty>::from_bits(if away_from_zero { bits + 1 } else { bits - 1 })
+            };
+            $variant(Some(stepped))
+        }};
+    }
+
+    match value {
+        Float32(Some(val)) => step_float!(Float32, f32, val),
+        Float64(Some(val)) => step_float!(Float64, f64, val),
+        Int8(Some(val)) => step_integer!(Int8, i8, val),
+        Int16(Some(val)) => step_integer!(Int16, i16, val),
+        Int32(Some(val)) => step_integer!(Int32, i32, val),
+        Int64(Some(val)) => step_integer!(Int64, i64, val),
+        UInt8(Some(val)) => step_integer!(UInt8, u8, val),
+        UInt16(Some(val)) => step_integer!(UInt16, u16, val),
+        UInt32(Some(val)) => step_integer!(UInt32, u32, val),
+        UInt64(Some(val)) => step_integer!(UInt64, u64, val),
+        _ => value, // Infinite bounds or unsupported datatypes
+    }
+}
+
+/// This function computes the cardinality ratio of the given intervals.
+pub fn cardinality_ratio(
+    initial_interval: &Interval,
+    final_interval: &Interval,
+) -> Result<f64> {
+    Ok(final_interval.cardinality()? as f64 / initial_interval.cardinality()? as f64)
}
/// Indicates whether interval arithmetic is supported for the given operator.
@@ -464,6 +604,7 @@ pub fn is_operator_supported(op: &Operator) -> bool {
| &Operator::GtEq
| &Operator::Lt
| &Operator::LtEq
+ | &Operator::Eq
)
}
@@ -508,11 +649,26 @@ fn cast_scalar_value(
ScalarValue::try_from_array(&cast_array, 0)
}
+/// Computes the number of distinct points in an interval from `diff`, the
+/// step distance between its endpoints, and whether each endpoint is open.
+///
+/// A closed interval holds `diff + 1` points and every open endpoint removes
+/// one. The arithmetic saturates, so an empty open interval whose endpoints
+/// coincide (e.g. `(5, 5)`) yields `0` instead of underflowing, and a
+/// full-range closed interval yields `u64::MAX` instead of overflowing.
+fn calculate_cardinality_based_on_bounds(
+    lower_open: bool,
+    upper_open: bool,
+    diff: u64,
+) -> u64 {
+    match (lower_open, upper_open) {
+        (false, false) => diff.saturating_add(1),
+        (true, true) => diff.saturating_sub(1),
+        _ => diff,
+    }
+}
+
#[cfg(test)]
mod tests {
+ use super::next_value;
use crate::intervals::{Interval, IntervalBound};
+
+ use arrow_schema::DataType;
use datafusion_common::{Result, ScalarValue};
- use ScalarValue::Boolean;
fn open_open<T>(lower: Option<T>, upper: Option<T>) -> Interval
where
@@ -1060,6 +1216,7 @@ mod tests {
// ([false, false], [false, true], [true, true]) for boolean intervals.
#[test]
fn non_standard_interval_constructs() {
+ use ScalarValue::Boolean;
let cases = vec![
(
IntervalBound::new(Boolean(None), true),
@@ -1172,4 +1329,138 @@ mod tests {
let upper = 1.5;
capture_mode_change_f32((lower, upper), true, true);
}
+
+    #[test]
+    fn test_cardinality_of_intervals() -> Result<()> {
+        // In IEEE 754 standard for floating-point arithmetic, if we keep the sign and
+        // exponent fields same, we can represent 4503599627370496 different numbers by
+        // changing the mantissa (4503599627370496 = 2^52, since there are 52 bits in
+        // mantissa, and 2^23 = 8388608 for f32).
+        let distinct_f64 = 4503599627370496;
+        let distinct_f32 = 8388608;
+        let intervals = [
+            Interval::new(
+                IntervalBound::new(ScalarValue::from(0.25), false),
+                IntervalBound::new(ScalarValue::from(0.50), true),
+            ),
+            Interval::new(
+                IntervalBound::new(ScalarValue::from(0.5), false),
+                IntervalBound::new(ScalarValue::from(1.0), true),
+            ),
+            Interval::new(
+                IntervalBound::new(ScalarValue::from(1.0), false),
+                IntervalBound::new(ScalarValue::from(2.0), true),
+            ),
+            Interval::new(
+                IntervalBound::new(ScalarValue::from(32.0), false),
+                IntervalBound::new(ScalarValue::from(64.0), true),
+            ),
+            Interval::new(
+                IntervalBound::new(ScalarValue::from(-0.50), false),
+                IntervalBound::new(ScalarValue::from(-0.25), true),
+            ),
+            Interval::new(
+                IntervalBound::new(ScalarValue::from(-32.0), false),
+                IntervalBound::new(ScalarValue::from(-16.0), true),
+            ),
+        ];
+        for interval in intervals {
+            assert_eq!(interval.cardinality()?, distinct_f64);
+        }
+
+        let intervals = [
+            Interval::new(
+                IntervalBound::new(ScalarValue::from(0.25_f32), false),
+                IntervalBound::new(ScalarValue::from(0.50_f32), true),
+            ),
+            Interval::new(
+                IntervalBound::new(ScalarValue::from(-1_f32), false),
+                IntervalBound::new(ScalarValue::from(-0.5_f32), true),
+            ),
+        ];
+        for interval in intervals {
+            assert_eq!(interval.cardinality()?, distinct_f32);
+        }
+
+        let interval = Interval::new(
+            IntervalBound::new(ScalarValue::from(-0.0625), false),
+            IntervalBound::new(ScalarValue::from(0.0625), true),
+        );
+        assert_eq!(interval.cardinality()?, distinct_f64 * 2_048);
+
+        let interval = Interval::new(
+            IntervalBound::new(ScalarValue::from(-0.0625_f32), false),
+            IntervalBound::new(ScalarValue::from(0.0625_f32), true),
+        );
+        assert_eq!(interval.cardinality()?, distinct_f32 * 256);
+
+        Ok(())
+    }
+
+    #[test]
+    fn test_next_value() -> Result<()> {
+        // Integer increment / decrement: 0 <-> 1 for every supported integer type.
+        // Plain `for` loops are used on purpose: assertions inside a discarded,
+        // lazy `Iterator::map` would never run.
+        let integer_types = [
+            DataType::UInt8,
+            DataType::UInt16,
+            DataType::UInt32,
+            DataType::UInt64,
+            DataType::Int8,
+            DataType::Int16,
+            DataType::Int32,
+            DataType::Int64,
+        ];
+        for data_type in &integer_types {
+            let zero = ScalarValue::new_zero(data_type)?;
+            let one = ScalarValue::new_one(data_type)?;
+            assert_eq!(next_value::<true>(zero.clone()), one);
+            assert_eq!(next_value::<false>(one), zero);
+        }
+
+        // Floating value increment / decrement around zero (including stepping
+        // below zero) moves by less than `eps`.
+        let values = [
+            ScalarValue::new_zero(&DataType::Float32)?,
+            ScalarValue::new_zero(&DataType::Float64)?,
+        ];
+        let eps = [
+            ScalarValue::Float32(Some(1e-6)),
+            ScalarValue::Float64(Some(1e-6)),
+        ];
+        for (v, e) in values.into_iter().zip(eps) {
+            assert!(next_value::<true>(v.clone()).sub(v.clone())?.lt(&e));
+            assert!(v.clone().sub(next_value::<false>(v))?.lt(&e));
+        }
+
+        // Negative floats must move in the requested direction as well.
+        let negatives = [
+            ScalarValue::Float32(Some(-1.0)),
+            ScalarValue::Float64(Some(-1.0)),
+        ];
+        for v in negatives {
+            assert!(next_value::<true>(v.clone()) > v);
+            assert!(next_value::<false>(v.clone()) < v);
+        }
+
+        // Min / Max values do not change
+        let mins = [
+            ScalarValue::UInt64(Some(u64::MIN)),
+            ScalarValue::Int8(Some(i8::MIN)),
+            ScalarValue::Float32(Some(f32::MIN)),
+            ScalarValue::Float64(Some(f64::MIN)),
+        ];
+        let maxs = [
+            ScalarValue::UInt64(Some(u64::MAX)),
+            ScalarValue::Int8(Some(i8::MAX)),
+            ScalarValue::Float32(Some(f32::MAX)),
+            ScalarValue::Float64(Some(f64::MAX)),
+        ];
+        for (min, max) in mins.into_iter().zip(maxs) {
+            assert_eq!(next_value::<true>(max.clone()), max);
+            assert_eq!(next_value::<false>(min.clone()), min);
+        }
+
+        Ok(())
+    }
}
diff --git a/datafusion/physical-expr/src/lib.rs
b/datafusion/physical-expr/src/lib.rs
index faa805dc92..4df5045a63 100644
--- a/datafusion/physical-expr/src/lib.rs
+++ b/datafusion/physical-expr/src/lib.rs
@@ -56,7 +56,9 @@ pub use equivalence::{
EquivalenceProperties, EquivalentClass, OrderingEquivalenceProperties,
OrderingEquivalentClass,
};
-pub use physical_expr::{AnalysisContext, ExprBoundaries, PhysicalExpr,
PhysicalExprRef};
+pub use physical_expr::{
+ analyze, AnalysisContext, ExprBoundaries, PhysicalExpr, PhysicalExprRef,
+};
pub use planner::create_physical_expr;
pub use scalar_function::ScalarFunctionExpr;
pub use sort_expr::{
diff --git a/datafusion/physical-expr/src/physical_expr.rs
b/datafusion/physical-expr/src/physical_expr.rs
index 68525920a0..19503d3a4c 100644
--- a/datafusion/physical-expr/src/physical_expr.rs
+++ b/datafusion/physical-expr/src/physical_expr.rs
@@ -15,24 +15,21 @@
// specific language governing permissions and limitations
// under the License.
-use arrow::datatypes::{DataType, Schema};
+use crate::expressions::Column;
+use crate::intervals::cp_solver::PropagationResult;
+use crate::intervals::{cardinality_ratio, ExprIntervalGraph, Interval,
IntervalBound};
+use crate::utils::collect_columns;
+use arrow::array::{make_array, Array, ArrayRef, BooleanArray,
MutableArrayData};
+use arrow::compute::{and_kleene, filter_record_batch, is_not_null,
SlicesIterator};
+use arrow::datatypes::{DataType, Schema};
use arrow::record_batch::RecordBatch;
-
use datafusion_common::utils::DataPtr;
-use datafusion_common::{
- ColumnStatistics, DataFusionError, Result, ScalarValue, Statistics,
-};
+use datafusion_common::{ColumnStatistics, DataFusionError, Result,
ScalarValue};
use datafusion_expr::ColumnarValue;
-use std::cmp::Ordering;
-use std::fmt::{Debug, Display};
-
-use arrow::array::{make_array, Array, ArrayRef, BooleanArray,
MutableArrayData};
-use arrow::compute::{and_kleene, filter_record_batch, is_not_null,
SlicesIterator};
-
-use crate::intervals::Interval;
use std::any::Any;
+use std::fmt::{Debug, Display};
use std::hash::{Hash, Hasher};
use std::sync::Arc;
@@ -79,12 +76,6 @@ pub trait PhysicalExpr: Send + Sync + Display + Debug + PartialEq<dyn Any> {
children: Vec<Arc<dyn PhysicalExpr>>,
) -> Result<Arc<dyn PhysicalExpr>>;
- /// Return the boundaries of this expression. This method (and all the
- /// related APIs) are experimental and subject to change.
- fn analyze(&self, context: AnalysisContext) -> AnalysisContext {
- context
- }
-
/// Computes bounds for the expression using interval arithmetic.
fn evaluate_bounds(&self, _children: &[&Interval]) -> Result<Interval> {
Err(DataFusionError::NotImplemented(format!(
@@ -139,6 +130,143 @@ pub trait PhysicalExpr: Send + Sync + Display + Debug + PartialEq<dyn Any> {
fn dyn_hash(&self, _state: &mut dyn Hasher);
}
+/// Attempts to refine column boundaries and compute a selectivity value.
+///
+/// The function accepts boundaries of the input columns in the `context` parameter.
+/// It then tries to tighten these boundaries based on the provided `expr`.
+/// The resulting selectivity value is calculated by comparing the initial and final boundaries.
+/// The computation assumes that the data within the column is uniformly distributed and not sorted.
+///
+/// # Arguments
+///
+/// * `expr` - The expression used to shrink the column boundaries.
+/// * `context` - The context holding input column boundaries.
+///
+/// # Returns
+///
+/// * `AnalysisContext` constructed from the pruned boundaries and a selectivity value.
+pub fn analyze(
+ expr: &Arc<dyn PhysicalExpr>,
+ context: AnalysisContext,
+) -> Result<AnalysisContext> {
+ let target_boundaries = context.boundaries.ok_or_else(|| {
+ DataFusionError::Internal("No column exists at the input to filter".to_string())
+ })?;
+
+ let mut graph = ExprIntervalGraph::try_new(expr.clone())?;
+
+ let columns: Vec<Arc<dyn PhysicalExpr>> = collect_columns(expr)
+ .into_iter()
+ .map(|c| Arc::new(c) as Arc<dyn PhysicalExpr>)
+ .collect();
+
+ let target_expr_and_indices: Vec<(Arc<dyn PhysicalExpr>, usize)> =
+ graph.gather_node_indices(columns.as_slice());
+
+ let mut target_indices_and_boundaries: Vec<(usize, Interval)> =
+ target_expr_and_indices
+ .iter()
+ .filter_map(|(expr, i)| {
+ target_boundaries.iter().find_map(|bound| {
+ expr.as_any()
+ .downcast_ref::<Column>()
+ .filter(|expr_column| bound.column.eq(*expr_column))
+ .map(|_| (*i, bound.interval.clone()))
+ })
+ })
+ .collect();
+
+ match graph.update_ranges(&mut target_indices_and_boundaries)? {
+ PropagationResult::Success => {
+ shrink_boundaries(expr, graph, target_boundaries, target_expr_and_indices)
+ }
+ PropagationResult::Infeasible => {
+ Ok(AnalysisContext::new(target_boundaries).with_selectivity(0.0))
+ }
+ PropagationResult::CannotPropagate => {
+ Ok(AnalysisContext::new(target_boundaries).with_selectivity(1.0))
+ }
+ }
+}
+
+/// If the `PropagationResult` indicates success, this function calculates the
+/// selectivity value by comparing the initial and final column boundaries.
+/// Following this, it constructs and returns a new `AnalysisContext` with the
+/// updated parameters.
+fn shrink_boundaries(
+ expr: &Arc<dyn PhysicalExpr>,
+ mut graph: ExprIntervalGraph,
+ mut target_boundaries: Vec<ExprBoundaries>,
+ target_expr_and_indices: Vec<(Arc<dyn PhysicalExpr>, usize)>,
+) -> Result<AnalysisContext> {
+ let initial_boundaries = target_boundaries.clone();
+ target_expr_and_indices.iter().for_each(|(expr, i)| {
+ if let Some(column) = expr.as_any().downcast_ref::<Column>() {
+ if let Some(bound) = target_boundaries
+ .iter_mut()
+ .find(|bound| bound.column.eq(column))
+ {
+ bound.interval = graph.get_interval(*i);
+ }
+ }
+ });
+ let graph_nodes = graph.gather_node_indices(&[expr.clone()]);
+ let (_, root_index) = graph_nodes.first().ok_or_else(|| {
+ DataFusionError::Internal("Error in constructing predicate graph".to_string())
+ })?;
+ let final_result = graph.get_interval(*root_index);
+
+ let selectivity = calculate_selectivity(
+ &final_result.lower.value,
+ &final_result.upper.value,
+ &target_boundaries,
+ &initial_boundaries,
+ )?;
+
+ if !(0.0..=1.0).contains(&selectivity) {
+ return Err(DataFusionError::Internal(format!(
+ "Selectivity is out of limit: {}",
+ selectivity
+ )));
+ }
+
+ Ok(AnalysisContext::new(target_boundaries).with_selectivity(selectivity))
+}
+
+/// This function calculates the filter predicate's selectivity by comparing
+/// the initial and pruned column boundaries. Selectivity is defined as the
+/// ratio of rows in a table that satisfy the filter's predicate.
+///
+/// An exact propagation result at the root, i.e. `[true, true]` or `[false,
false]`,
+/// leads to early exit (returning a selectivity value of either 1.0 or 0.0).
In such
+/// a case, `[true, true]` indicates that all data values satisfy the
predicate (hence,
+/// selectivity is 1.0), and `[false, false]` suggests that no data value
meets the
+/// predicate (therefore, selectivity is 0.0).
+fn calculate_selectivity(
+ lower_value: &ScalarValue,
+ upper_value: &ScalarValue,
+ target_boundaries: &[ExprBoundaries],
+ initial_boundaries: &[ExprBoundaries],
+) -> Result<f64> {
+ match (lower_value, upper_value) {
+ (ScalarValue::Boolean(Some(true)), ScalarValue::Boolean(Some(true)))
=> Ok(1.0),
+ (ScalarValue::Boolean(Some(false)), ScalarValue::Boolean(Some(false)))
=> Ok(0.0),
+ _ => {
+ // Since the intervals are assumed uniform and the values
+ // are not correlated, we need to multiply the selectivities
+ // of multiple columns to get the overall selectivity.
+ target_boundaries.iter().enumerate().try_fold(
+ 1.0,
+ |acc, (i, ExprBoundaries { interval, .. })| {
+ let temp =
+ cardinality_ratio(&initial_boundaries[i].interval,
interval)?;
+ Ok(acc * temp)
+ },
+ )
+ }
+ }
+}
+
impl Hash for dyn PhysicalExpr {
fn hash<H: Hasher>(&self, state: &mut H) {
self.dyn_hash(state);
@@ -152,58 +280,42 @@ pub type PhysicalExprRef = Arc<dyn PhysicalExpr>;
/// the boundaries for all known columns.
#[derive(Clone, Debug, PartialEq)]
pub struct AnalysisContext {
- /// A list of known column boundaries, ordered by the index
- /// of the column in the current schema.
- pub column_boundaries: Vec<Option<ExprBoundaries>>,
- // Result of the current analysis.
- pub boundaries: Option<ExprBoundaries>,
+ /// A list of known column boundaries, ordered by the index
+ /// of the column in the current schema.
+ pub boundaries: Option<Vec<ExprBoundaries>>,
+ /// The estimated percentage of rows that this expression would select, if
+ /// it were to be used as a boolean predicate on a filter. The value will be
+ /// between 0.0 (selects nothing) and 1.0 (selects everything).
+ pub selectivity: Option<f64>,
}
impl AnalysisContext {
- pub fn new(
- input_schema: &Schema,
- column_boundaries: Vec<Option<ExprBoundaries>>,
- ) -> Self {
- assert_eq!(input_schema.fields().len(), column_boundaries.len());
+ pub fn new(boundaries: Vec<ExprBoundaries>) -> Self {
Self {
- column_boundaries,
- boundaries: None,
+ boundaries: Some(boundaries),
+ selectivity: None,
}
}
- /// Create a new analysis context from column statistics.
- pub fn from_statistics(input_schema: &Schema, statistics: &Statistics) -> Self {
- // Even if the underlying statistics object doesn't have any column level statistics,
- // we can still create an analysis context with the same number of columns and see whether
- // we can infer it during the way.
- let column_boundaries = match &statistics.column_statistics {
- Some(columns) => columns
- .iter()
- .map(ExprBoundaries::from_column)
- .collect::<Vec<_>>(),
- None => vec![None; input_schema.fields().len()],
- };
- Self::new(input_schema, column_boundaries)
- }
-
- pub fn boundaries(&self) -> Option<&ExprBoundaries> {
- self.boundaries.as_ref()
- }
-
- /// Set the result of the current analysis.
- pub fn with_boundaries(mut self, result: Option<ExprBoundaries>) -> Self {
- self.boundaries = result;
+ pub fn with_selectivity(mut self, selectivity: f64) -> Self {
+ self.selectivity = Some(selectivity);
self
}
- /// Update the boundaries of a column.
- pub fn with_column_update(
- mut self,
- column: usize,
- boundaries: ExprBoundaries,
+ /// Create a new analysis context from column statistics.
+ pub fn from_statistics(
+ input_schema: &Schema,
+ statistics: &[ColumnStatistics],
) -> Self {
- self.column_boundaries[column] = Some(boundaries);
- self
+ let mut column_boundaries = Vec::with_capacity(statistics.len());
+ for (idx, stats) in statistics.iter().enumerate() {
+ column_boundaries.push(ExprBoundaries::from_column(
+ stats,
+ input_schema.fields()[idx].name().clone(),
+ idx,
+ ));
+ }
+ Self::new(column_boundaries)
}
}
@@ -211,64 +323,29 @@ impl AnalysisContext {
/// if it were to be an expression, if it were to be evaluated.
#[derive(Clone, Debug, PartialEq)]
pub struct ExprBoundaries {
- /// Minimum value this expression's result can have.
- pub min_value: ScalarValue,
- /// Maximum value this expression's result can have.
- pub max_value: ScalarValue,
+ pub column: Column,
+ /// Minimum and maximum values this expression can have.
+ pub interval: Interval,
/// Maximum number of distinct values this expression can produce, if known.
pub distinct_count: Option<usize>,
- /// The estimated percantage of rows that this expression would select, if
- /// it were to be used as a boolean predicate on a filter. The value will be
- /// between 0.0 (selects nothing) and 1.0 (selects everything).
- pub selectivity: Option<f64>,
}
impl ExprBoundaries {
- /// Create a new `ExprBoundaries`.
- pub fn new(
- min_value: ScalarValue,
- max_value: ScalarValue,
- distinct_count: Option<usize>,
- ) -> Self {
- Self::new_with_selectivity(min_value, max_value, distinct_count, None)
- }
-
- /// Create a new `ExprBoundaries` with a selectivity value.
- pub fn new_with_selectivity(
- min_value: ScalarValue,
- max_value: ScalarValue,
- distinct_count: Option<usize>,
- selectivity: Option<f64>,
- ) -> Self {
- assert!(!matches!(
- min_value.partial_cmp(&max_value),
- Some(Ordering::Greater)
- ));
+ /// Create a new `ExprBoundaries` object from column level statistics.
+ pub fn from_column(stats: &ColumnStatistics, col: String, index: usize) -> Self {
Self {
- min_value,
- max_value,
- distinct_count,
- selectivity,
- }
- }
-
- /// Create a new `ExprBoundaries` from a column level statistics.
- pub fn from_column(column: &ColumnStatistics) -> Option<Self> {
- Some(Self {
- min_value: column.min_value.clone()?,
- max_value: column.max_value.clone()?,
- distinct_count: column.distinct_count,
- selectivity: None,
- })
- }
-
- /// Try to reduce the boundaries into a single scalar value, if possible.
- pub fn reduce(&self) -> Option<ScalarValue> {
- // TODO: should we check distinct_count is `Some(1) | None`?
- if self.min_value == self.max_value {
- Some(self.min_value.clone())
- } else {
- None
+ column: Column::new(&col, index),
+ interval: Interval::new(
+ IntervalBound::new(
+ stats.min_value.clone().unwrap_or(ScalarValue::Null),
+ false,
+ ),
+ IntervalBound::new(
+ stats.max_value.clone().unwrap_or(ScalarValue::Null),
+ false,
+ ),
+ ),
+ distinct_count: stats.distinct_count,
}
}
}
@@ -360,7 +437,7 @@ macro_rules! analysis_expect {
($context: ident, $expr: expr) => {
match $expr {
Some(expr) => expr,
- None => return $context.with_boundaries(None),
+ None => return Ok($context.with_boundaries(None)),
}
};
}
@@ -441,31 +518,4 @@ mod tests {
assert_eq!(&expected, result);
Ok(())
}
-
- #[test]
- fn reduce_boundaries() -> Result<()> {
- let different_boundaries = ExprBoundaries::new(
- ScalarValue::Int32(Some(1)),
- ScalarValue::Int32(Some(10)),
- None,
- );
- assert_eq!(different_boundaries.reduce(), None);
-
- let scalar_boundaries = ExprBoundaries::new(
- ScalarValue::Int32(Some(1)),
- ScalarValue::Int32(Some(1)),
- None,
- );
- assert_eq!(
- scalar_boundaries.reduce(),
- Some(ScalarValue::Int32(Some(1)))
- );
-
- // Can still reduce.
- let no_boundaries =
- ExprBoundaries::new(ScalarValue::Int32(None),
ScalarValue::Int32(None), None);
- assert_eq!(no_boundaries.reduce(), Some(ScalarValue::Int32(None)));
-
- Ok(())
- }
}