This is an automated email from the ASF dual-hosted git repository.
wjones127 pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/arrow-datafusion.git
The following commit(s) were added to refs/heads/main by this push:
new 8946f8bd34 feat: add guarantees to simplification (#7467)
8946f8bd34 is described below
commit 8946f8bd34e3d29009b5cbe41da8ef4d04af421a
Author: Will Jones <[email protected]>
AuthorDate: Wed Sep 13 14:09:51 2023 -0400
feat: add guarantees to simplification (#7467)
* feat: add guarantees to simplification
* null and comparison support
* add support for literal expressions
* implement inlist guarantee use
* test the outer function
* docs
* refactor to use intervals
* add high-level test
* cleanup
* fix test to be false or null, not true
* refactor: change NullableInterval to an enum
* refactor: use a builder-like API
* pr feedback
* Fix clippy
* fix doc links
---------
Co-authored-by: Andrew Lamb <[email protected]>
---
.../src/simplify_expressions/expr_simplifier.rs | 177 ++++++-
.../src/simplify_expressions/guarantees.rs | 520 +++++++++++++++++++++
.../optimizer/src/simplify_expressions/mod.rs | 1 +
.../src/intervals/interval_aritmetic.rs | 311 ++++++++++++
4 files changed, 1006 insertions(+), 3 deletions(-)
diff --git a/datafusion/optimizer/src/simplify_expressions/expr_simplifier.rs
b/datafusion/optimizer/src/simplify_expressions/expr_simplifier.rs
index c92660c7bb..f5a6860299 100644
--- a/datafusion/optimizer/src/simplify_expressions/expr_simplifier.rs
+++ b/datafusion/optimizer/src/simplify_expressions/expr_simplifier.rs
@@ -39,13 +39,20 @@ use datafusion_expr::{
and, expr, lit, or, BinaryExpr, BuiltinScalarFunction, Case,
ColumnarValue, Expr,
Like, Volatility,
};
-use datafusion_physical_expr::{create_physical_expr,
execution_props::ExecutionProps};
+use datafusion_physical_expr::{
+ create_physical_expr, execution_props::ExecutionProps,
intervals::NullableInterval,
+};
use crate::simplify_expressions::SimplifyInfo;
+use crate::simplify_expressions::guarantees::GuaranteeRewriter;
+
/// This structure handles API for expression simplification
pub struct ExprSimplifier<S> {
info: S,
+ /// Guarantees about the values of columns. This is provided by the user
+ /// in [ExprSimplifier::with_guarantees()].
+ guarantees: Vec<(Expr, NullableInterval)>,
}
pub const THRESHOLD_INLINE_INLIST: usize = 3;
@@ -57,7 +64,10 @@ impl<S: SimplifyInfo> ExprSimplifier<S> {
///
/// [`SimplifyContext`]:
crate::simplify_expressions::context::SimplifyContext
pub fn new(info: S) -> Self {
- Self { info }
+ Self {
+ info,
+ guarantees: vec![],
+ }
}
/// Simplifies this [`Expr`]`s as much as possible, evaluating
@@ -121,6 +131,7 @@ impl<S: SimplifyInfo> ExprSimplifier<S> {
let mut simplifier = Simplifier::new(&self.info);
let mut const_evaluator =
ConstEvaluator::try_new(self.info.execution_props())?;
let mut or_in_list_simplifier = OrInListSimplifier::new();
+ let mut guarantee_rewriter = GuaranteeRewriter::new(&self.guarantees);
// TODO iterate until no changes are made during rewrite
// (evaluating constants can enable new simplifications and
@@ -129,6 +140,7 @@ impl<S: SimplifyInfo> ExprSimplifier<S> {
expr.rewrite(&mut const_evaluator)?
.rewrite(&mut simplifier)?
.rewrite(&mut or_in_list_simplifier)?
+ .rewrite(&mut guarantee_rewriter)?
// run both passes twice to try an minimize simplifications that
we missed
.rewrite(&mut const_evaluator)?
.rewrite(&mut simplifier)
@@ -149,6 +161,65 @@ impl<S: SimplifyInfo> ExprSimplifier<S> {
expr.rewrite(&mut expr_rewrite)
}
+
+ /// Input guarantees about the values of columns.
+ ///
+ /// The guarantees can simplify expressions. For example, if a column `x` is
+ /// guaranteed to be `3`, then the expression `x > 1` can be replaced by the
+ /// literal `true`.
+ ///
+ /// The guarantees are provided as a `Vec<(Expr, NullableInterval)>`,
+ /// where the [Expr] is a column reference and the [NullableInterval]
+ /// is an interval representing the known possible values of that column.
+ ///
+ /// ```rust
+ /// use arrow::datatypes::{DataType, Field, Schema};
+ /// use datafusion_expr::{col, lit, Expr};
+ /// use datafusion_common::{Result, ScalarValue, ToDFSchema};
+ /// use datafusion_physical_expr::execution_props::ExecutionProps;
+ /// use datafusion_physical_expr::intervals::{Interval, NullableInterval};
+ /// use datafusion_optimizer::simplify_expressions::{
+ /// ExprSimplifier, SimplifyContext};
+ ///
+ /// let schema = Schema::new(vec![
+ /// Field::new("x", DataType::Int64, false),
+ /// Field::new("y", DataType::UInt32, false),
+ /// Field::new("z", DataType::Int64, false),
+ /// ])
+ /// .to_dfschema_ref().unwrap();
+ ///
+ /// // Create the simplifier
+ /// let props = ExecutionProps::new();
+ /// let context = SimplifyContext::new(&props)
+ /// .with_schema(schema);
+ ///
+ /// // Expression: (x >= 3) AND (y + 2 < 10) AND (z > 5)
+ /// let expr_x = col("x").gt_eq(lit(3_i64));
+ /// let expr_y = (col("y") + lit(2_u32)).lt(lit(10_u32));
+ /// let expr_z = col("z").gt(lit(5_i64));
+ /// let expr = expr_x.and(expr_y).and(expr_z.clone());
+ ///
+ /// let guarantees = vec![
+ /// // x ∈ [3, 5]
+ /// (
+ /// col("x"),
+ /// NullableInterval::NotNull {
+ /// values: Interval::make(Some(3_i64), Some(5_i64), (false,
false)),
+ /// }
+ /// ),
+ /// // y = 3
+ /// (col("y"), NullableInterval::from(ScalarValue::UInt32(Some(3)))),
+ /// ];
+ /// let simplifier =
ExprSimplifier::new(context).with_guarantees(guarantees);
+ /// let output = simplifier.simplify(expr).unwrap();
+ /// // Expression becomes: true AND true AND (z > 5), which simplifies to
+ /// // z > 5.
+ /// assert_eq!(output, expr_z);
+ /// ```
+ pub fn with_guarantees(mut self, guarantees: Vec<(Expr,
NullableInterval)>) -> Self {
+ self.guarantees = guarantees;
+ self
+ }
}
#[allow(rustdoc::private_intra_doc_links)]
@@ -1239,7 +1310,9 @@ mod tests {
use datafusion_common::{assert_contains, cast::as_int32_array, DFField,
ToDFSchema};
use datafusion_expr::*;
use datafusion_physical_expr::{
- execution_props::ExecutionProps, functions::make_scalar_function,
+ execution_props::ExecutionProps,
+ functions::make_scalar_function,
+ intervals::{Interval, NullableInterval},
};
// ------------------------------
@@ -2703,6 +2776,19 @@ mod tests {
try_simplify(expr).unwrap()
}
+ fn simplify_with_guarantee(
+ expr: Expr,
+ guarantees: Vec<(Expr, NullableInterval)>,
+ ) -> Expr {
+ let schema = expr_test_schema();
+ let execution_props = ExecutionProps::new();
+ let simplifier = ExprSimplifier::new(
+ SimplifyContext::new(&execution_props).with_schema(schema),
+ )
+ .with_guarantees(guarantees);
+ simplifier.simplify(expr).unwrap()
+ }
+
fn expr_test_schema() -> DFSchemaRef {
Arc::new(
DFSchema::new_with_metadata(
@@ -3166,4 +3252,89 @@ mod tests {
let expr = not_ilike(null, "%");
assert_eq!(simplify(expr), lit_bool_null());
}
+
+ #[test]
+ fn test_simplify_with_guarantee() {
+ // (c3 > 3) AND (c4 + 2 < 10 OR (c1 NOT IN ("a", "b")))
+ let expr_x = col("c3").gt(lit(3_i64));
+ let expr_y = (col("c4") + lit(2_u32)).lt(lit(10_u32));
+ let expr_z = col("c1").in_list(vec![lit("a"), lit("b")], true);
+ let expr = expr_x.clone().and(expr_y.clone().or(expr_z));
+
+ // All guaranteed null
+ let guarantees = vec![
+ (col("c3"), NullableInterval::from(ScalarValue::Int64(None))),
+ (col("c4"), NullableInterval::from(ScalarValue::UInt32(None))),
+ (col("c1"), NullableInterval::from(ScalarValue::Utf8(None))),
+ ];
+
+ let output = simplify_with_guarantee(expr.clone(), guarantees);
+ assert_eq!(output, lit_bool_null());
+
+ // All guaranteed false
+ let guarantees = vec![
+ (
+ col("c3"),
+ NullableInterval::NotNull {
+ values: Interval::make(Some(0_i64), Some(2_i64), (false,
false)),
+ },
+ ),
+ (
+ col("c4"),
+ NullableInterval::from(ScalarValue::UInt32(Some(9))),
+ ),
+ (
+ col("c1"),
+
NullableInterval::from(ScalarValue::Utf8(Some("a".to_string()))),
+ ),
+ ];
+ let output = simplify_with_guarantee(expr.clone(), guarantees);
+ assert_eq!(output, lit(false));
+
+ // `c3 > 3` is guaranteed false or null -> it is kept as-is; the rest simplifies away.
+ let guarantees = vec![
+ (
+ col("c3"),
+ NullableInterval::MaybeNull {
+ values: Interval::make(Some(0_i64), Some(2_i64), (false,
false)),
+ },
+ ),
+ (
+ col("c4"),
+ NullableInterval::MaybeNull {
+ values: Interval::make(Some(9_u32), Some(9_u32), (false,
false)),
+ },
+ ),
+ (
+ col("c1"),
+ NullableInterval::NotNull {
+ values: Interval::make(Some("d"), Some("f"), (false,
false)),
+ },
+ ),
+ ];
+ let output = simplify_with_guarantee(expr.clone(), guarantees);
+ assert_eq!(&output, &expr_x);
+
+ // Sufficient true guarantees
+ let guarantees = vec![
+ (
+ col("c3"),
+ NullableInterval::from(ScalarValue::Int64(Some(9))),
+ ),
+ (
+ col("c4"),
+ NullableInterval::from(ScalarValue::UInt32(Some(3))),
+ ),
+ ];
+ let output = simplify_with_guarantee(expr.clone(), guarantees);
+ assert_eq!(output, lit(true));
+
+ // Only partially simplify
+ let guarantees = vec![(
+ col("c4"),
+ NullableInterval::from(ScalarValue::UInt32(Some(3))),
+ )];
+ let output = simplify_with_guarantee(expr.clone(), guarantees);
+ assert_eq!(&output, &expr_x);
+ }
}
diff --git a/datafusion/optimizer/src/simplify_expressions/guarantees.rs
b/datafusion/optimizer/src/simplify_expressions/guarantees.rs
new file mode 100644
index 0000000000..5504d7d76e
--- /dev/null
+++ b/datafusion/optimizer/src/simplify_expressions/guarantees.rs
@@ -0,0 +1,520 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+//! Simplifier implementation for [`ExprSimplifier::with_guarantees()`]
+//!
+//! [`ExprSimplifier::with_guarantees()`]:
crate::simplify_expressions::expr_simplifier::ExprSimplifier::with_guarantees
+use datafusion_common::{tree_node::TreeNodeRewriter, DataFusionError, Result};
+use datafusion_expr::{expr::InList, lit, Between, BinaryExpr, Expr};
+use std::collections::HashMap;
+
+use datafusion_physical_expr::intervals::{Interval, IntervalBound,
NullableInterval};
+
+/// Rewrite expressions to incorporate guarantees.
+///
+/// Guarantees are a mapping from an expression (which currently is always a
+/// column reference) to a [NullableInterval]. The interval represents the known
+/// possible values of the column. Using these known values, expressions are
+/// rewritten so they can be simplified using `ConstEvaluator` and `Simplifier`.
+///
+/// For example, if we know that a column is not null and has values in the
+/// range [1, 10), we can rewrite `x IS NULL` to `false` or `x < 10` to `true`.
+///
+/// See a full example in [`ExprSimplifier::with_guarantees()`].
+///
+/// [`ExprSimplifier::with_guarantees()`]:
crate::simplify_expressions::expr_simplifier::ExprSimplifier::with_guarantees
+pub(crate) struct GuaranteeRewriter<'a> {
+ guarantees: HashMap<&'a Expr, &'a NullableInterval>,
+}
+
+impl<'a> GuaranteeRewriter<'a> {
+ pub fn new(
+ guarantees: impl IntoIterator<Item = &'a (Expr, NullableInterval)>,
+ ) -> Self {
+ Self {
+ guarantees: guarantees.into_iter().map(|(k, v)| (k, v)).collect(),
+ }
+ }
+}
+
+impl<'a> TreeNodeRewriter for GuaranteeRewriter<'a> {
+ type N = Expr;
+
+ fn mutate(&mut self, expr: Expr) -> Result<Expr> {
+ if self.guarantees.is_empty() {
+ return Ok(expr);
+ }
+
+ match &expr {
+ Expr::IsNull(inner) => match self.guarantees.get(inner.as_ref()) {
+ Some(NullableInterval::Null { .. }) => Ok(lit(true)),
+ Some(NullableInterval::NotNull { .. }) => Ok(lit(false)),
+ _ => Ok(expr),
+ },
+ Expr::IsNotNull(inner) => match
self.guarantees.get(inner.as_ref()) {
+ Some(NullableInterval::Null { .. }) => Ok(lit(false)),
+ Some(NullableInterval::NotNull { .. }) => Ok(lit(true)),
+ _ => Ok(expr),
+ },
+ Expr::Between(Between {
+ expr: inner,
+ negated,
+ low,
+ high,
+ }) => {
+ if let (Some(interval), Expr::Literal(low),
Expr::Literal(high)) = (
+ self.guarantees.get(inner.as_ref()),
+ low.as_ref(),
+ high.as_ref(),
+ ) {
+ let expr_interval = NullableInterval::NotNull {
+ values: Interval::new(
+ IntervalBound::new(low.clone(), false),
+ IntervalBound::new(high.clone(), false),
+ ),
+ };
+
+ let contains = expr_interval.contains(*interval)?;
+
+ if contains.is_certainly_true() {
+ Ok(lit(!negated))
+ } else if contains.is_certainly_false() {
+ Ok(lit(*negated))
+ } else {
+ Ok(expr)
+ }
+ } else {
+ Ok(expr)
+ }
+ }
+
+ Expr::BinaryExpr(BinaryExpr { left, op, right }) => {
+ // We only support comparisons for now
+ if !op.is_comparison_operator() {
+ return Ok(expr);
+ };
+
+ // Check if this is a comparison between a column and literal
+ let (col, op, value) = match (left.as_ref(), right.as_ref()) {
+ (Expr::Column(_), Expr::Literal(value)) => (left, *op,
value),
+ (Expr::Literal(value), Expr::Column(_)) => {
+ // If we can swap the op, we can simplify the
expression
+ if let Some(op) = op.swap() {
+ (right, op, value)
+ } else {
+ return Ok(expr);
+ }
+ }
+ _ => return Ok(expr),
+ };
+
+ if let Some(col_interval) = self.guarantees.get(col.as_ref()) {
+ let result =
+ col_interval.apply_operator(&op,
&value.clone().into())?;
+ if result.is_certainly_true() {
+ Ok(lit(true))
+ } else if result.is_certainly_false() {
+ Ok(lit(false))
+ } else {
+ Ok(expr)
+ }
+ } else {
+ Ok(expr)
+ }
+ }
+
+ // Columns (if interval is collapsed to a single value)
+ Expr::Column(_) => {
+ if let Some(col_interval) = self.guarantees.get(&expr) {
+ if let Some(value) = col_interval.single_value() {
+ Ok(lit(value))
+ } else {
+ Ok(expr)
+ }
+ } else {
+ Ok(expr)
+ }
+ }
+
+ Expr::InList(InList {
+ expr: inner,
+ list,
+ negated,
+ }) => {
+ if let Some(interval) = self.guarantees.get(inner.as_ref()) {
+ // Can remove items from the list that don't match the
guarantee
+ let new_list: Vec<Expr> = list
+ .iter()
+ .filter_map(|expr| {
+ if let Expr::Literal(item) = expr {
+ match interval
+
.contains(&NullableInterval::from(item.clone()))
+ {
+ // If we know for certain the value isn't
in the column's interval,
+ // we can skip checking it.
+ Ok(interval) if
interval.is_certainly_false() => None,
+ Ok(_) => Some(Ok(expr.clone())),
+ Err(e) => Some(Err(e)),
+ }
+ } else {
+ Some(Ok(expr.clone()))
+ }
+ })
+ .collect::<Result<_, DataFusionError>>()?;
+
+ Ok(Expr::InList(InList {
+ expr: inner.clone(),
+ list: new_list,
+ negated: *negated,
+ }))
+ } else {
+ Ok(expr)
+ }
+ }
+
+ _ => Ok(expr),
+ }
+ }
+}
+
+#[cfg(test)]
+mod tests {
+ use super::*;
+
+ use arrow::datatypes::DataType;
+ use datafusion_common::{tree_node::TreeNode, ScalarValue};
+ use datafusion_expr::{col, lit, Operator};
+
+ #[test]
+ fn test_null_handling() {
+ // IsNull / IsNotNull can be rewritten to true / false
+ let guarantees = vec![
+ // Note: AlwaysNull case handled by test_column_single_value test,
+ // since it's a special case of a column with a single value.
+ (
+ col("x"),
+ NullableInterval::NotNull {
+ values: Default::default(),
+ },
+ ),
+ ];
+ let mut rewriter = GuaranteeRewriter::new(guarantees.iter());
+
+ // x IS NULL => guaranteed false
+ let expr = col("x").is_null();
+ let output = expr.clone().rewrite(&mut rewriter).unwrap();
+ assert_eq!(output, lit(false));
+
+ // x IS NOT NULL => guaranteed true
+ let expr = col("x").is_not_null();
+ let output = expr.clone().rewrite(&mut rewriter).unwrap();
+ assert_eq!(output, lit(true));
+ }
+
+ fn validate_simplified_cases<T>(rewriter: &mut GuaranteeRewriter, cases:
&[(Expr, T)])
+ where
+ ScalarValue: From<T>,
+ T: Clone,
+ {
+ for (expr, expected_value) in cases {
+ let output = expr.clone().rewrite(rewriter).unwrap();
+ let expected = lit(ScalarValue::from(expected_value.clone()));
+ assert_eq!(
+ output, expected,
+ "{} simplified to {}, but expected {}",
+ expr, output, expected
+ );
+ }
+ }
+
+ fn validate_unchanged_cases(rewriter: &mut GuaranteeRewriter, cases:
&[Expr]) {
+ for expr in cases {
+ let output = expr.clone().rewrite(rewriter).unwrap();
+ assert_eq!(
+ &output, expr,
+ "{} was simplified to {}, but expected it to be unchanged",
+ expr, output
+ );
+ }
+ }
+
+ #[test]
+ fn test_inequalities_non_null_bounded() {
+ let guarantees = vec![
+ // x ∈ (1, 3] (not null)
+ (
+ col("x"),
+ NullableInterval::NotNull {
+ values: Interval::make(Some(1_i32), Some(3_i32), (true,
false)),
+ },
+ ),
+ ];
+
+ let mut rewriter = GuaranteeRewriter::new(guarantees.iter());
+
+ // (original_expr, expected_simplification)
+ let simplified_cases = &[
+ (col("x").lt_eq(lit(1)), false),
+ (col("x").lt_eq(lit(3)), true),
+ (col("x").gt(lit(3)), false),
+ (col("x").gt(lit(1)), true),
+ (col("x").eq(lit(0)), false),
+ (col("x").not_eq(lit(0)), true),
+ (col("x").between(lit(2), lit(5)), true),
+ (col("x").between(lit(2), lit(3)), true),
+ (col("x").between(lit(5), lit(10)), false),
+ (col("x").not_between(lit(2), lit(5)), false),
+ (col("x").not_between(lit(2), lit(3)), false),
+ (col("x").not_between(lit(5), lit(10)), true),
+ (
+ Expr::BinaryExpr(BinaryExpr {
+ left: Box::new(col("x")),
+ op: Operator::IsDistinctFrom,
+ right: Box::new(lit(ScalarValue::Null)),
+ }),
+ true,
+ ),
+ (
+ Expr::BinaryExpr(BinaryExpr {
+ left: Box::new(col("x")),
+ op: Operator::IsDistinctFrom,
+ right: Box::new(lit(5)),
+ }),
+ true,
+ ),
+ ];
+
+ validate_simplified_cases(&mut rewriter, simplified_cases);
+
+ let unchanged_cases = &[
+ col("x").gt(lit(2)),
+ col("x").lt_eq(lit(2)),
+ col("x").eq(lit(2)),
+ col("x").not_eq(lit(2)),
+ col("x").between(lit(3), lit(5)),
+ col("x").not_between(lit(3), lit(10)),
+ ];
+
+ validate_unchanged_cases(&mut rewriter, unchanged_cases);
+ }
+
+ #[test]
+ fn test_inequalities_non_null_unbounded() {
+ let guarantees = vec![
+ // x ∈ [2021-01-01, ∞) (not null)
+ (
+ col("x"),
+ NullableInterval::NotNull {
+ values: Interval::new(
+ IntervalBound::new(ScalarValue::Date32(Some(18628)),
false),
+
IntervalBound::make_unbounded(DataType::Date32).unwrap(),
+ ),
+ },
+ ),
+ ];
+ let mut rewriter = GuaranteeRewriter::new(guarantees.iter());
+
+ // (original_expr, expected_simplification)
+ let simplified_cases = &[
+ (col("x").lt(lit(ScalarValue::Date32(Some(18628)))), false),
+ (col("x").lt_eq(lit(ScalarValue::Date32(Some(17000)))), false),
+ (col("x").gt(lit(ScalarValue::Date32(Some(18627)))), true),
+ (col("x").gt_eq(lit(ScalarValue::Date32(Some(18628)))), true),
+ (col("x").eq(lit(ScalarValue::Date32(Some(17000)))), false),
+ (col("x").not_eq(lit(ScalarValue::Date32(Some(17000)))), true),
+ (
+ col("x").between(
+ lit(ScalarValue::Date32(Some(16000))),
+ lit(ScalarValue::Date32(Some(17000))),
+ ),
+ false,
+ ),
+ (
+ col("x").not_between(
+ lit(ScalarValue::Date32(Some(16000))),
+ lit(ScalarValue::Date32(Some(17000))),
+ ),
+ true,
+ ),
+ (
+ Expr::BinaryExpr(BinaryExpr {
+ left: Box::new(col("x")),
+ op: Operator::IsDistinctFrom,
+ right: Box::new(lit(ScalarValue::Null)),
+ }),
+ true,
+ ),
+ (
+ Expr::BinaryExpr(BinaryExpr {
+ left: Box::new(col("x")),
+ op: Operator::IsDistinctFrom,
+ right: Box::new(lit(ScalarValue::Date32(Some(17000)))),
+ }),
+ true,
+ ),
+ ];
+
+ validate_simplified_cases(&mut rewriter, simplified_cases);
+
+ let unchanged_cases = &[
+ col("x").lt(lit(ScalarValue::Date32(Some(19000)))),
+ col("x").lt_eq(lit(ScalarValue::Date32(Some(19000)))),
+ col("x").gt(lit(ScalarValue::Date32(Some(19000)))),
+ col("x").gt_eq(lit(ScalarValue::Date32(Some(19000)))),
+ col("x").eq(lit(ScalarValue::Date32(Some(19000)))),
+ col("x").not_eq(lit(ScalarValue::Date32(Some(19000)))),
+ col("x").between(
+ lit(ScalarValue::Date32(Some(18000))),
+ lit(ScalarValue::Date32(Some(19000))),
+ ),
+ col("x").not_between(
+ lit(ScalarValue::Date32(Some(18000))),
+ lit(ScalarValue::Date32(Some(19000))),
+ ),
+ ];
+
+ validate_unchanged_cases(&mut rewriter, unchanged_cases);
+ }
+
+ #[test]
+ fn test_inequalities_maybe_null() {
+ let guarantees = vec![
+ // x ∈ ("abc", "def"]? (maybe null)
+ (
+ col("x"),
+ NullableInterval::MaybeNull {
+ values: Interval::make(Some("abc"), Some("def"), (true,
false)),
+ },
+ ),
+ ];
+ let mut rewriter = GuaranteeRewriter::new(guarantees.iter());
+
+ // (original_expr, expected_simplification)
+ let simplified_cases = &[
+ (
+ Expr::BinaryExpr(BinaryExpr {
+ left: Box::new(col("x")),
+ op: Operator::IsDistinctFrom,
+ right: Box::new(lit("z")),
+ }),
+ true,
+ ),
+ (
+ Expr::BinaryExpr(BinaryExpr {
+ left: Box::new(col("x")),
+ op: Operator::IsNotDistinctFrom,
+ right: Box::new(lit("z")),
+ }),
+ false,
+ ),
+ ];
+
+ validate_simplified_cases(&mut rewriter, simplified_cases);
+
+ let unchanged_cases = &[
+ col("x").lt(lit("z")),
+ col("x").lt_eq(lit("z")),
+ col("x").gt(lit("a")),
+ col("x").gt_eq(lit("a")),
+ col("x").eq(lit("abc")),
+ col("x").not_eq(lit("a")),
+ col("x").between(lit("a"), lit("z")),
+ col("x").not_between(lit("a"), lit("z")),
+ Expr::BinaryExpr(BinaryExpr {
+ left: Box::new(col("x")),
+ op: Operator::IsDistinctFrom,
+ right: Box::new(lit(ScalarValue::Null)),
+ }),
+ ];
+
+ validate_unchanged_cases(&mut rewriter, unchanged_cases);
+ }
+
+ #[test]
+ fn test_column_single_value() {
+ let scalars = [
+ ScalarValue::Null,
+ ScalarValue::Int32(Some(1)),
+ ScalarValue::Boolean(Some(true)),
+ ScalarValue::Boolean(None),
+ ScalarValue::Utf8(Some("abc".to_string())),
+ ScalarValue::LargeUtf8(Some("def".to_string())),
+ ScalarValue::Date32(Some(18628)),
+ ScalarValue::Date32(None),
+ ScalarValue::Decimal128(Some(1000), 19, 2),
+ ];
+
+ for scalar in scalars {
+ let guarantees = vec![(col("x"),
NullableInterval::from(scalar.clone()))];
+ let mut rewriter = GuaranteeRewriter::new(guarantees.iter());
+
+ let output = col("x").rewrite(&mut rewriter).unwrap();
+ assert_eq!(output, Expr::Literal(scalar.clone()));
+ }
+ }
+
+ #[test]
+ fn test_in_list() {
+ let guarantees = vec![
+ // x ∈ [1, 10) (not null)
+ (
+ col("x"),
+ NullableInterval::NotNull {
+ values: Interval::make(Some(1_i32), Some(10_i32), (false,
true)),
+ },
+ ),
+ ];
+ let mut rewriter = GuaranteeRewriter::new(guarantees.iter());
+
+ // These cases should be simplified so the list doesn't contain any
+ // values the guarantee says are outside the range.
+ // (column_name, starting_list, negated, expected_list)
+ let cases = &[
+ // x IN (9, 11) => x IN (9)
+ ("x", vec![9, 11], false, vec![9]),
+ // x IN (10, 2) => x IN (2)
+ ("x", vec![10, 2], false, vec![2]),
+ // x NOT IN (9, 11) => x NOT IN (9)
+ ("x", vec![9, 11], true, vec![9]),
+ // x NOT IN (0, 22) => x NOT IN ()
+ ("x", vec![0, 22], true, vec![]),
+ ];
+
+ for (column_name, starting_list, negated, expected_list) in cases {
+ let expr = col(*column_name).in_list(
+ starting_list
+ .iter()
+ .map(|v| lit(ScalarValue::Int32(Some(*v))))
+ .collect(),
+ *negated,
+ );
+ let output = expr.clone().rewrite(&mut rewriter).unwrap();
+ let expected_list = expected_list
+ .iter()
+ .map(|v| lit(ScalarValue::Int32(Some(*v))))
+ .collect();
+ assert_eq!(
+ output,
+ Expr::InList(InList {
+ expr: Box::new(col(*column_name)),
+ list: expected_list,
+ negated: *negated,
+ })
+ );
+ }
+ }
+}
diff --git a/datafusion/optimizer/src/simplify_expressions/mod.rs
b/datafusion/optimizer/src/simplify_expressions/mod.rs
index dfa0fe7043..2cf6ed166c 100644
--- a/datafusion/optimizer/src/simplify_expressions/mod.rs
+++ b/datafusion/optimizer/src/simplify_expressions/mod.rs
@@ -17,6 +17,7 @@
pub mod context;
pub mod expr_simplifier;
+mod guarantees;
mod or_in_list_simplifier;
mod regex;
pub mod simplify_exprs;
diff --git a/datafusion/physical-expr/src/intervals/interval_aritmetic.rs
b/datafusion/physical-expr/src/intervals/interval_aritmetic.rs
index 3f72ef588c..5501c8cae0 100644
--- a/datafusion/physical-expr/src/intervals/interval_aritmetic.rs
+++ b/datafusion/physical-expr/src/intervals/interval_aritmetic.rs
@@ -396,6 +396,22 @@ impl Interval {
}
}
+ /// Compute the logical negation of this (boolean) interval.
+ pub(crate) fn not(&self) -> Result<Self> {
+ if !matches!(self.get_datatype()?, DataType::Boolean) {
+ return internal_err!(
+ "Cannot apply logical negation to non-boolean interval"
+ );
+ }
+ if self == &Interval::CERTAINLY_TRUE {
+ Ok(Interval::CERTAINLY_FALSE)
+ } else if self == &Interval::CERTAINLY_FALSE {
+ Ok(Interval::CERTAINLY_TRUE)
+ } else {
+ Ok(Interval::UNCERTAIN)
+ }
+ }
+
/// Compute the intersection of the interval with the given interval.
/// If the intersection is empty, return None.
pub(crate) fn intersect<T: Borrow<Interval>>(
@@ -426,6 +442,23 @@ impl Interval {
Ok(non_empty.then_some(Interval::new(lower, upper)))
}
+ /// Decide if this interval certainly contains, possibly contains,
+ /// or can't contain `other` by returning [true, true],
+ /// [false, true] or [false, false] respectively.
+ pub fn contains<T: Borrow<Self>>(&self, other: T) -> Result<Self> {
+ match self.intersect(other.borrow())? {
+ Some(intersection) => {
+ // Need to compare with same bounds close-ness.
+ if intersection.close_bounds() ==
other.borrow().clone().close_bounds() {
+ Ok(Interval::CERTAINLY_TRUE)
+ } else {
+ Ok(Interval::UNCERTAIN)
+ }
+ }
+ None => Ok(Interval::CERTAINLY_FALSE),
+ }
+ }
+
/// Add the given interval (`other`) to this interval. Say we have
/// intervals [a1, b1] and [a2, b2], then their sum is [a1 + a2, b1 + b2].
/// Note that this represents all possible values the sum can take if
@@ -633,6 +666,7 @@ pub fn cardinality_ratio(
pub fn apply_operator(op: &Operator, lhs: &Interval, rhs: &Interval) ->
Result<Interval> {
match *op {
Operator::Eq => Ok(lhs.equal(rhs)),
+ Operator::NotEq => Ok(lhs.equal(rhs).not()?),
Operator::Gt => Ok(lhs.gt(rhs)),
Operator::GtEq => Ok(lhs.gt_eq(rhs)),
Operator::Lt => Ok(lhs.lt(rhs)),
@@ -667,6 +701,283 @@ fn calculate_cardinality_based_on_bounds(
}
}
+/// An [Interval] that also tracks whether its value is always, possibly, or never null.
+///
+/// This represents values that may be in a particular range or be null.
+///
+/// # Examples
+///
+/// ```
+/// use arrow::datatypes::DataType;
+/// use datafusion_physical_expr::intervals::{Interval, NullableInterval};
+/// use datafusion_common::ScalarValue;
+///
+/// // [1, 2) U {NULL}
+/// NullableInterval::MaybeNull {
+/// values: Interval::make(Some(1), Some(2), (false, true)),
+/// };
+///
+/// // (0, ∞)
+/// NullableInterval::NotNull {
+/// values: Interval::make(Some(0), None, (true, true)),
+/// };
+///
+/// // {NULL}
+/// NullableInterval::Null { datatype: DataType::Int32 };
+///
+/// // {4}
+/// NullableInterval::from(ScalarValue::Int32(Some(4)));
+/// ```
+#[derive(Debug, Clone, PartialEq, Eq)]
+pub enum NullableInterval {
+ /// The value is always null in this interval
+ ///
+ /// This is typed so it can be used in physical expressions, which don't do
+ /// type coercion.
+ Null { datatype: DataType },
+ /// The value may or may not be null in this interval. If it is non-null, its value is within
+ /// the specified values interval.
+ MaybeNull { values: Interval },
+ /// The value is definitely not null in this interval and is within values
+ NotNull { values: Interval },
+}
+
+impl Default for NullableInterval {
+ fn default() -> Self {
+ NullableInterval::MaybeNull {
+ values: Interval::default(),
+ }
+ }
+}
+
+impl Display for NullableInterval {
+ fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result {
+ match self {
+ Self::Null { .. } => write!(f, "NullableInterval: {{NULL}}"),
+ Self::MaybeNull { values } => {
+ write!(f, "NullableInterval: {} U {{NULL}}", values)
+ }
+ Self::NotNull { values } => write!(f, "NullableInterval: {}",
values),
+ }
+ }
+}
+
+impl From<ScalarValue> for NullableInterval {
+ /// Create an interval that represents a single value.
+ fn from(value: ScalarValue) -> Self {
+ if value.is_null() {
+ Self::Null {
+ datatype: value.data_type(),
+ }
+ } else {
+ Self::NotNull {
+ values: Interval::new(
+ IntervalBound::new(value.clone(), false),
+ IntervalBound::new(value, false),
+ ),
+ }
+ }
+ }
+}
+
+impl NullableInterval {
+ /// Get the values interval, or None if this interval is definitely null.
+ pub fn values(&self) -> Option<&Interval> {
+ match self {
+ Self::Null { .. } => None,
+ Self::MaybeNull { values } | Self::NotNull { values } =>
Some(values),
+ }
+ }
+
+ /// Get the data type
+ pub fn get_datatype(&self) -> Result<DataType> {
+ match self {
+ Self::Null { datatype } => Ok(datatype.clone()),
+ Self::MaybeNull { values } | Self::NotNull { values } => {
+ values.get_datatype()
+ }
+ }
+ }
+
+ /// Return true if the value is definitely true (and not null).
+ pub fn is_certainly_true(&self) -> bool {
+ match self {
+ Self::Null { .. } | Self::MaybeNull { .. } => false,
+ Self::NotNull { values } => values == &Interval::CERTAINLY_TRUE,
+ }
+ }
+
+ /// Return true if the value is definitely false (and not null).
+ pub fn is_certainly_false(&self) -> bool {
+ match self {
+ Self::Null { .. } => false,
+ Self::MaybeNull { .. } => false,
+ Self::NotNull { values } => values == &Interval::CERTAINLY_FALSE,
+ }
+ }
+
+ /// Perform logical negation on a boolean nullable interval.
+ fn not(&self) -> Result<Self> {
+ match self {
+ Self::Null { datatype } => Ok(Self::Null {
+ datatype: datatype.clone(),
+ }),
+ Self::MaybeNull { values } => Ok(Self::MaybeNull {
+ values: values.not()?,
+ }),
+ Self::NotNull { values } => Ok(Self::NotNull {
+ values: values.not()?,
+ }),
+ }
+ }
+
+ /// Apply the given operator to this interval and the given interval.
+ ///
+ /// # Examples
+ ///
+ /// ```
+ /// use datafusion_common::ScalarValue;
+ /// use datafusion_expr::Operator;
+ /// use datafusion_physical_expr::intervals::{Interval, NullableInterval};
+ ///
+ /// // 4 > 3 -> true
+ /// let lhs = NullableInterval::from(ScalarValue::Int32(Some(4)));
+ /// let rhs = NullableInterval::from(ScalarValue::Int32(Some(3)));
+ /// let result = lhs.apply_operator(&Operator::Gt, &rhs).unwrap();
+ /// assert_eq!(result,
NullableInterval::from(ScalarValue::Boolean(Some(true))));
+ ///
+ /// // [1, 3) > NULL -> NULL
+ /// let lhs = NullableInterval::NotNull {
+ /// values: Interval::make(Some(1), Some(3), (false, true)),
+ /// };
+ /// let rhs = NullableInterval::from(ScalarValue::Int32(None));
+ /// let result = lhs.apply_operator(&Operator::Gt, &rhs).unwrap();
+ /// assert_eq!(result.single_value(), Some(ScalarValue::Boolean(None)));
+ ///
+ /// // [1, 3] > [2, 4] -> [false, true]
+ /// let lhs = NullableInterval::NotNull {
+ /// values: Interval::make(Some(1), Some(3), (false, false)),
+ /// };
+ /// let rhs = NullableInterval::NotNull {
+ /// values: Interval::make(Some(2), Some(4), (false, false)),
+ /// };
+ /// let result = lhs.apply_operator(&Operator::Gt, &rhs).unwrap();
+ /// // Both inputs are valid (non-null), so result must be non-null
+ /// assert_eq!(result, NullableInterval::NotNull {
+ /// // Uncertain whether inequality is true or false
+ /// values: Interval::UNCERTAIN,
+ /// });
+ ///
+ /// ```
+ pub fn apply_operator(&self, op: &Operator, rhs: &Self) -> Result<Self> {
+ match op {
+ Operator::IsDistinctFrom => {
+ let values = match (self, rhs) {
+ // NULL is distinct from NULL -> False
+ (Self::Null { .. }, Self::Null { .. }) =>
Interval::CERTAINLY_FALSE,
+ // x is distinct from y -> x != y,
+ // if at least one of them is never null.
+ (Self::NotNull { .. }, _) | (_, Self::NotNull { .. }) => {
+ let lhs_values = self.values();
+ let rhs_values = rhs.values();
+ match (lhs_values, rhs_values) {
+ (Some(lhs_values), Some(rhs_values)) => {
+ lhs_values.equal(rhs_values).not()?
+ }
+ (Some(_), None) | (None, Some(_)) =>
Interval::CERTAINLY_TRUE,
+ (None, None) => unreachable!("Null case handled
above"),
+ }
+ }
+ _ => Interval::UNCERTAIN,
+ };
+ // IsDistinctFrom never returns null.
+ Ok(Self::NotNull { values })
+ }
+ Operator::IsNotDistinctFrom => self
+ .apply_operator(&Operator::IsDistinctFrom, rhs)
+ .map(|i| i.not())?,
+ _ => {
+ if let (Some(left_values), Some(right_values)) =
+ (self.values(), rhs.values())
+ {
+ let values = apply_operator(op, left_values,
right_values)?;
+ match (self, rhs) {
+ (Self::NotNull { .. }, Self::NotNull { .. }) => {
+ Ok(Self::NotNull { values })
+ }
+ _ => Ok(Self::MaybeNull { values }),
+ }
+ } else if op.is_comparison_operator() {
+ Ok(Self::Null {
+ datatype: DataType::Boolean,
+ })
+ } else {
+ Ok(Self::Null {
+ datatype: self.get_datatype()?,
+ })
+ }
+ }
+ }
+ }
+
+ /// Determine if this interval contains the given interval. Returns a boolean
+ /// interval that is [true, true] if this interval is a superset of the
+ /// given interval, [false, false] if this interval is disjoint from the
+ /// given interval, and [false, true] otherwise.
+ pub fn contains<T: Borrow<Self>>(&self, other: T) -> Result<Self> {
+ let rhs = other.borrow();
+ if let (Some(left_values), Some(right_values)) = (self.values(),
rhs.values()) {
+ let values = left_values.contains(right_values)?;
+ match (self, rhs) {
+ (Self::NotNull { .. }, Self::NotNull { .. }) => {
+ Ok(Self::NotNull { values })
+ }
+ _ => Ok(Self::MaybeNull { values }),
+ }
+ } else {
+ Ok(Self::Null {
+ datatype: DataType::Boolean,
+ })
+ }
+ }
+
+ /// If the interval has collapsed to a single value, return that value.
+ ///
+ /// Otherwise returns None.
+ ///
+ /// # Examples
+ ///
+ /// ```
+ /// use datafusion_common::ScalarValue;
+ /// use datafusion_physical_expr::intervals::{Interval, NullableInterval};
+ ///
+ /// let interval = NullableInterval::from(ScalarValue::Int32(Some(4)));
+ /// assert_eq!(interval.single_value(), Some(ScalarValue::Int32(Some(4))));
+ ///
+ /// let interval = NullableInterval::from(ScalarValue::Int32(None));
+ /// assert_eq!(interval.single_value(), Some(ScalarValue::Int32(None)));
+ ///
+ /// let interval = NullableInterval::MaybeNull {
+ /// values: Interval::make(Some(1), Some(4), (false, true)),
+ /// };
+ /// assert_eq!(interval.single_value(), None);
+ /// ```
+ pub fn single_value(&self) -> Option<ScalarValue> {
+ match self {
+ Self::Null { datatype } => {
+
Some(ScalarValue::try_from(datatype).unwrap_or(ScalarValue::Null))
+ }
+ Self::MaybeNull { values } | Self::NotNull { values }
+ if values.lower.value == values.upper.value
+ && !values.lower.is_unbounded() =>
+ {
+ Some(values.lower.value.clone())
+ }
+ _ => None,
+ }
+ }
+}
+
#[cfg(test)]
mod tests {
use super::next_value;