This is an automated email from the ASF dual-hosted git repository.

wjones127 pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/arrow-datafusion.git


The following commit(s) were added to refs/heads/main by this push:
     new 8946f8bd34 feat: add guarantees to simplification (#7467)
8946f8bd34 is described below

commit 8946f8bd34e3d29009b5cbe41da8ef4d04af421a
Author: Will Jones <[email protected]>
AuthorDate: Wed Sep 13 14:09:51 2023 -0400

    feat: add guarantees to simplification (#7467)
    
    * feat: add guarantees to simplification
    
    * null and comparison support
    
    * add support for literal expressions
    
    * implement inlist guarantee use
    
    * test the outer function
    
    * docs
    
    * refactor to use intervals
    
    * add high-level test
    
    * cleanup
    
    * fix test to be false or null, not true
    
    * refactor: change NullableInterval to an enum
    
    * refactor: use a builder-like API
    
    * pr feedback
    
    * Fix clippy
    
    * fix doc links
    
    ---------
    
    Co-authored-by: Andrew Lamb <[email protected]>
---
 .../src/simplify_expressions/expr_simplifier.rs    | 177 ++++++-
 .../src/simplify_expressions/guarantees.rs         | 520 +++++++++++++++++++++
 .../optimizer/src/simplify_expressions/mod.rs      |   1 +
 .../src/intervals/interval_aritmetic.rs            | 311 ++++++++++++
 4 files changed, 1006 insertions(+), 3 deletions(-)

diff --git a/datafusion/optimizer/src/simplify_expressions/expr_simplifier.rs 
b/datafusion/optimizer/src/simplify_expressions/expr_simplifier.rs
index c92660c7bb..f5a6860299 100644
--- a/datafusion/optimizer/src/simplify_expressions/expr_simplifier.rs
+++ b/datafusion/optimizer/src/simplify_expressions/expr_simplifier.rs
@@ -39,13 +39,20 @@ use datafusion_expr::{
     and, expr, lit, or, BinaryExpr, BuiltinScalarFunction, Case, 
ColumnarValue, Expr,
     Like, Volatility,
 };
-use datafusion_physical_expr::{create_physical_expr, 
execution_props::ExecutionProps};
+use datafusion_physical_expr::{
+    create_physical_expr, execution_props::ExecutionProps, 
intervals::NullableInterval,
+};
 
 use crate::simplify_expressions::SimplifyInfo;
 
+use crate::simplify_expressions::guarantees::GuaranteeRewriter;
+
 /// This structure handles API for expression simplification
 pub struct ExprSimplifier<S> {
     info: S,
+    /// Guarantees about the values of columns. This is provided by the user
+    /// in [ExprSimplifier::with_guarantees()].
+    guarantees: Vec<(Expr, NullableInterval)>,
 }
 
 pub const THRESHOLD_INLINE_INLIST: usize = 3;
@@ -57,7 +64,10 @@ impl<S: SimplifyInfo> ExprSimplifier<S> {
     ///
     /// [`SimplifyContext`]: 
crate::simplify_expressions::context::SimplifyContext
     pub fn new(info: S) -> Self {
-        Self { info }
+        Self {
+            info,
+            guarantees: vec![],
+        }
     }
 
     /// Simplifies this [`Expr`]`s as much as possible, evaluating
@@ -121,6 +131,7 @@ impl<S: SimplifyInfo> ExprSimplifier<S> {
         let mut simplifier = Simplifier::new(&self.info);
         let mut const_evaluator = 
ConstEvaluator::try_new(self.info.execution_props())?;
         let mut or_in_list_simplifier = OrInListSimplifier::new();
+        let mut guarantee_rewriter = GuaranteeRewriter::new(&self.guarantees);
 
         // TODO iterate until no changes are made during rewrite
         // (evaluating constants can enable new simplifications and
@@ -129,6 +140,7 @@ impl<S: SimplifyInfo> ExprSimplifier<S> {
         expr.rewrite(&mut const_evaluator)?
             .rewrite(&mut simplifier)?
             .rewrite(&mut or_in_list_simplifier)?
+            .rewrite(&mut guarantee_rewriter)?
             // run both passes twice to try an minimize simplifications that 
we missed
             .rewrite(&mut const_evaluator)?
             .rewrite(&mut simplifier)
@@ -149,6 +161,65 @@ impl<S: SimplifyInfo> ExprSimplifier<S> {
 
         expr.rewrite(&mut expr_rewrite)
     }
+
+    /// Input guarantees about the values of columns.
+    ///
+    /// The guarantees can simplify expressions. For example, if a column `x` 
is
+    /// guaranteed to be `3`, then the expression `x > 1` can be replaced by 
the
+    /// literal `true`.
+    ///
+    /// The guarantees are provided as a `Vec<(Expr, NullableInterval)>`,
+    /// where the [Expr] is a column reference and the [NullableInterval]
+    /// is an interval representing the known possible values of that column.
+    ///
+    /// ```rust
+    /// use arrow::datatypes::{DataType, Field, Schema};
+    /// use datafusion_expr::{col, lit, Expr};
+    /// use datafusion_common::{Result, ScalarValue, ToDFSchema};
+    /// use datafusion_physical_expr::execution_props::ExecutionProps;
+    /// use datafusion_physical_expr::intervals::{Interval, NullableInterval};
+    /// use datafusion_optimizer::simplify_expressions::{
+    ///     ExprSimplifier, SimplifyContext};
+    ///
+    /// let schema = Schema::new(vec![
+    ///   Field::new("x", DataType::Int64, false),
+    ///   Field::new("y", DataType::UInt32, false),
+    ///   Field::new("z", DataType::Int64, false),
+    ///   ])
+    ///   .to_dfschema_ref().unwrap();
+    ///
+    /// // Create the simplifier
+    /// let props = ExecutionProps::new();
+    /// let context = SimplifyContext::new(&props)
+    ///    .with_schema(schema);
+    ///
+    /// // Expression: (x >= 3) AND (y + 2 < 10) AND (z > 5)
+    /// let expr_x = col("x").gt_eq(lit(3_i64));
+    /// let expr_y = (col("y") + lit(2_u32)).lt(lit(10_u32));
+    /// let expr_z = col("z").gt(lit(5_i64));
+    /// let expr = expr_x.and(expr_y).and(expr_z.clone());
+    ///
+    /// let guarantees = vec![
+    ///    // x ∈ [3, 5]
+    ///    (
+    ///        col("x"),
+    ///        NullableInterval::NotNull {
+    ///            values: Interval::make(Some(3_i64), Some(5_i64), (false, 
false)),
+    ///        }
+    ///    ),
+    ///    // y = 3
+    ///    (col("y"), NullableInterval::from(ScalarValue::UInt32(Some(3)))),
+    /// ];
+    /// let simplifier = 
ExprSimplifier::new(context).with_guarantees(guarantees);
+    /// let output = simplifier.simplify(expr).unwrap();
+    /// // Expression becomes: true AND true AND (z > 5), which simplifies to
+    /// // z > 5.
+    /// assert_eq!(output, expr_z);
+    /// ```
+    pub fn with_guarantees(mut self, guarantees: Vec<(Expr, 
NullableInterval)>) -> Self {
+        self.guarantees = guarantees;
+        self
+    }
 }
 
 #[allow(rustdoc::private_intra_doc_links)]
@@ -1239,7 +1310,9 @@ mod tests {
     use datafusion_common::{assert_contains, cast::as_int32_array, DFField, 
ToDFSchema};
     use datafusion_expr::*;
     use datafusion_physical_expr::{
-        execution_props::ExecutionProps, functions::make_scalar_function,
+        execution_props::ExecutionProps,
+        functions::make_scalar_function,
+        intervals::{Interval, NullableInterval},
     };
 
     // ------------------------------
@@ -2703,6 +2776,19 @@ mod tests {
         try_simplify(expr).unwrap()
     }
 
+    fn simplify_with_guarantee(
+        expr: Expr,
+        guarantees: Vec<(Expr, NullableInterval)>,
+    ) -> Expr {
+        let schema = expr_test_schema();
+        let execution_props = ExecutionProps::new();
+        let simplifier = ExprSimplifier::new(
+            SimplifyContext::new(&execution_props).with_schema(schema),
+        )
+        .with_guarantees(guarantees);
+        simplifier.simplify(expr).unwrap()
+    }
+
     fn expr_test_schema() -> DFSchemaRef {
         Arc::new(
             DFSchema::new_with_metadata(
@@ -3166,4 +3252,89 @@ mod tests {
         let expr = not_ilike(null, "%");
         assert_eq!(simplify(expr), lit_bool_null());
     }
+
+    #[test]
+    fn test_simplify_with_guarantee() {
+        // (c3 > 3) AND (c4 + 2 < 10 OR (c1 NOT IN ("a", "b")))
+        let expr_x = col("c3").gt(lit(3_i64));
+        let expr_y = (col("c4") + lit(2_u32)).lt(lit(10_u32));
+        let expr_z = col("c1").in_list(vec![lit("a"), lit("b")], true);
+        let expr = expr_x.clone().and(expr_y.clone().or(expr_z));
+
+        // All guaranteed null
+        let guarantees = vec![
+            (col("c3"), NullableInterval::from(ScalarValue::Int64(None))),
+            (col("c4"), NullableInterval::from(ScalarValue::UInt32(None))),
+            (col("c1"), NullableInterval::from(ScalarValue::Utf8(None))),
+        ];
+
+        let output = simplify_with_guarantee(expr.clone(), guarantees);
+        assert_eq!(output, lit_bool_null());
+
+        // All guaranteed false
+        let guarantees = vec![
+            (
+                col("c3"),
+                NullableInterval::NotNull {
+                    values: Interval::make(Some(0_i64), Some(2_i64), (false, 
false)),
+                },
+            ),
+            (
+                col("c4"),
+                NullableInterval::from(ScalarValue::UInt32(Some(9))),
+            ),
+            (
+                col("c1"),
+                
NullableInterval::from(ScalarValue::Utf8(Some("a".to_string()))),
+            ),
+        ];
+        let output = simplify_with_guarantee(expr.clone(), guarantees);
+        assert_eq!(output, lit(false));
+
+        // `c3 > 3` is guaranteed false or null -> that comparison is left unchanged.
+        let guarantees = vec![
+            (
+                col("c3"),
+                NullableInterval::MaybeNull {
+                    values: Interval::make(Some(0_i64), Some(2_i64), (false, 
false)),
+                },
+            ),
+            (
+                col("c4"),
+                NullableInterval::MaybeNull {
+                    values: Interval::make(Some(9_u32), Some(9_u32), (false, 
false)),
+                },
+            ),
+            (
+                col("c1"),
+                NullableInterval::NotNull {
+                    values: Interval::make(Some("d"), Some("f"), (false, 
false)),
+                },
+            ),
+        ];
+        let output = simplify_with_guarantee(expr.clone(), guarantees);
+        assert_eq!(&output, &expr_x);
+
+        // Sufficient true guarantees
+        let guarantees = vec![
+            (
+                col("c3"),
+                NullableInterval::from(ScalarValue::Int64(Some(9))),
+            ),
+            (
+                col("c4"),
+                NullableInterval::from(ScalarValue::UInt32(Some(3))),
+            ),
+        ];
+        let output = simplify_with_guarantee(expr.clone(), guarantees);
+        assert_eq!(output, lit(true));
+
+        // Only partially simplify
+        let guarantees = vec![(
+            col("c4"),
+            NullableInterval::from(ScalarValue::UInt32(Some(3))),
+        )];
+        let output = simplify_with_guarantee(expr.clone(), guarantees);
+        assert_eq!(&output, &expr_x);
+    }
 }
diff --git a/datafusion/optimizer/src/simplify_expressions/guarantees.rs 
b/datafusion/optimizer/src/simplify_expressions/guarantees.rs
new file mode 100644
index 0000000000..5504d7d76e
--- /dev/null
+++ b/datafusion/optimizer/src/simplify_expressions/guarantees.rs
@@ -0,0 +1,520 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements.  See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership.  The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License.  You may obtain a copy of the License at
+//
+//   http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied.  See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+//! Simplifier implementation for [`ExprSimplifier::with_guarantees()`]
+//!
+//! [`ExprSimplifier::with_guarantees()`]: 
crate::simplify_expressions::expr_simplifier::ExprSimplifier::with_guarantees
+use datafusion_common::{tree_node::TreeNodeRewriter, DataFusionError, Result};
+use datafusion_expr::{expr::InList, lit, Between, BinaryExpr, Expr};
+use std::collections::HashMap;
+
+use datafusion_physical_expr::intervals::{Interval, IntervalBound, 
NullableInterval};
+
+/// Rewrite expressions to incorporate guarantees.
+///
+/// Guarantees are a mapping from an expression (which currently is always a
+/// column reference) to a [NullableInterval]. The interval represents the 
known
+/// possible values of the column. Using these known values, expressions are
+/// rewritten so they can be simplified using `ConstEvaluator` and 
`Simplifier`.
+///
+/// For example, if we know that a column is not null and has values in the
+/// range [1, 10), we can rewrite `x IS NULL` to `false` or `x < 10` to `true`.
+///
+/// See a full example in [`ExprSimplifier::with_guarantees()`].
+///
+/// [`ExprSimplifier::with_guarantees()`]: 
crate::simplify_expressions::expr_simplifier::ExprSimplifier::with_guarantees
+pub(crate) struct GuaranteeRewriter<'a> {
+    guarantees: HashMap<&'a Expr, &'a NullableInterval>,
+}
+
+impl<'a> GuaranteeRewriter<'a> {
+    pub fn new(
+        guarantees: impl IntoIterator<Item = &'a (Expr, NullableInterval)>,
+    ) -> Self {
+        Self {
+            guarantees: guarantees.into_iter().map(|(k, v)| (k, v)).collect(),
+        }
+    }
+}
+
+impl<'a> TreeNodeRewriter for GuaranteeRewriter<'a> {
+    type N = Expr;
+
+    fn mutate(&mut self, expr: Expr) -> Result<Expr> {
+        if self.guarantees.is_empty() {
+            return Ok(expr);
+        }
+
+        match &expr {
+            Expr::IsNull(inner) => match self.guarantees.get(inner.as_ref()) {
+                Some(NullableInterval::Null { .. }) => Ok(lit(true)),
+                Some(NullableInterval::NotNull { .. }) => Ok(lit(false)),
+                _ => Ok(expr),
+            },
+            Expr::IsNotNull(inner) => match 
self.guarantees.get(inner.as_ref()) {
+                Some(NullableInterval::Null { .. }) => Ok(lit(false)),
+                Some(NullableInterval::NotNull { .. }) => Ok(lit(true)),
+                _ => Ok(expr),
+            },
+            Expr::Between(Between {
+                expr: inner,
+                negated,
+                low,
+                high,
+            }) => {
+                if let (Some(interval), Expr::Literal(low), 
Expr::Literal(high)) = (
+                    self.guarantees.get(inner.as_ref()),
+                    low.as_ref(),
+                    high.as_ref(),
+                ) {
+                    let expr_interval = NullableInterval::NotNull {
+                        values: Interval::new(
+                            IntervalBound::new(low.clone(), false),
+                            IntervalBound::new(high.clone(), false),
+                        ),
+                    };
+
+                    let contains = expr_interval.contains(*interval)?;
+
+                    if contains.is_certainly_true() {
+                        Ok(lit(!negated))
+                    } else if contains.is_certainly_false() {
+                        Ok(lit(*negated))
+                    } else {
+                        Ok(expr)
+                    }
+                } else {
+                    Ok(expr)
+                }
+            }
+
+            Expr::BinaryExpr(BinaryExpr { left, op, right }) => {
+                // We only support comparisons for now
+                if !op.is_comparison_operator() {
+                    return Ok(expr);
+                };
+
+                // Check if this is a comparison between a column and literal
+                let (col, op, value) = match (left.as_ref(), right.as_ref()) {
+                    (Expr::Column(_), Expr::Literal(value)) => (left, *op, 
value),
+                    (Expr::Literal(value), Expr::Column(_)) => {
+                        // If we can swap the op, we can simplify the 
expression
+                        if let Some(op) = op.swap() {
+                            (right, op, value)
+                        } else {
+                            return Ok(expr);
+                        }
+                    }
+                    _ => return Ok(expr),
+                };
+
+                if let Some(col_interval) = self.guarantees.get(col.as_ref()) {
+                    let result =
+                        col_interval.apply_operator(&op, 
&value.clone().into())?;
+                    if result.is_certainly_true() {
+                        Ok(lit(true))
+                    } else if result.is_certainly_false() {
+                        Ok(lit(false))
+                    } else {
+                        Ok(expr)
+                    }
+                } else {
+                    Ok(expr)
+                }
+            }
+
+            // Columns (if interval is collapsed to a single value)
+            Expr::Column(_) => {
+                if let Some(col_interval) = self.guarantees.get(&expr) {
+                    if let Some(value) = col_interval.single_value() {
+                        Ok(lit(value))
+                    } else {
+                        Ok(expr)
+                    }
+                } else {
+                    Ok(expr)
+                }
+            }
+
+            Expr::InList(InList {
+                expr: inner,
+                list,
+                negated,
+            }) => {
+                if let Some(interval) = self.guarantees.get(inner.as_ref()) {
+                    // Can remove items from the list that don't match the 
guarantee
+                    let new_list: Vec<Expr> = list
+                        .iter()
+                        .filter_map(|expr| {
+                            if let Expr::Literal(item) = expr {
+                                match interval
+                                    
.contains(&NullableInterval::from(item.clone()))
+                                {
+                                    // If we know for certain the value isn't 
in the column's interval,
+                                    // we can skip checking it.
+                                    Ok(interval) if 
interval.is_certainly_false() => None,
+                                    Ok(_) => Some(Ok(expr.clone())),
+                                    Err(e) => Some(Err(e)),
+                                }
+                            } else {
+                                Some(Ok(expr.clone()))
+                            }
+                        })
+                        .collect::<Result<_, DataFusionError>>()?;
+
+                    Ok(Expr::InList(InList {
+                        expr: inner.clone(),
+                        list: new_list,
+                        negated: *negated,
+                    }))
+                } else {
+                    Ok(expr)
+                }
+            }
+
+            _ => Ok(expr),
+        }
+    }
+}
+
+#[cfg(test)]
+mod tests {
+    use super::*;
+
+    use arrow::datatypes::DataType;
+    use datafusion_common::{tree_node::TreeNode, ScalarValue};
+    use datafusion_expr::{col, lit, Operator};
+
+    #[test]
+    fn test_null_handling() {
+        // IsNull / IsNotNull can be rewritten to true / false
+        let guarantees = vec![
+            // Note: the always-null (`NullableInterval::Null`) case is handled by the
+            // test_column_single_value test, since it's a special case of a column with a single value.
+            (
+                col("x"),
+                NullableInterval::NotNull {
+                    values: Default::default(),
+                },
+            ),
+        ];
+        let mut rewriter = GuaranteeRewriter::new(guarantees.iter());
+
+        // x IS NULL => guaranteed false
+        let expr = col("x").is_null();
+        let output = expr.clone().rewrite(&mut rewriter).unwrap();
+        assert_eq!(output, lit(false));
+
+        // x IS NOT NULL => guaranteed true
+        let expr = col("x").is_not_null();
+        let output = expr.clone().rewrite(&mut rewriter).unwrap();
+        assert_eq!(output, lit(true));
+    }
+
+    fn validate_simplified_cases<T>(rewriter: &mut GuaranteeRewriter, cases: 
&[(Expr, T)])
+    where
+        ScalarValue: From<T>,
+        T: Clone,
+    {
+        for (expr, expected_value) in cases {
+            let output = expr.clone().rewrite(rewriter).unwrap();
+            let expected = lit(ScalarValue::from(expected_value.clone()));
+            assert_eq!(
+                output, expected,
+                "{} simplified to {}, but expected {}",
+                expr, output, expected
+            );
+        }
+    }
+
+    fn validate_unchanged_cases(rewriter: &mut GuaranteeRewriter, cases: 
&[Expr]) {
+        for expr in cases {
+            let output = expr.clone().rewrite(rewriter).unwrap();
+            assert_eq!(
+                &output, expr,
+                "{} was simplified to {}, but expected it to be unchanged",
+                expr, output
+            );
+        }
+    }
+
+    #[test]
+    fn test_inequalities_non_null_bounded() {
+        let guarantees = vec![
+            // x ∈ (1, 3] (not null)
+            (
+                col("x"),
+                NullableInterval::NotNull {
+                    values: Interval::make(Some(1_i32), Some(3_i32), (true, 
false)),
+                },
+            ),
+        ];
+
+        let mut rewriter = GuaranteeRewriter::new(guarantees.iter());
+
+        // (original_expr, expected_simplification)
+        let simplified_cases = &[
+            (col("x").lt_eq(lit(1)), false),
+            (col("x").lt_eq(lit(3)), true),
+            (col("x").gt(lit(3)), false),
+            (col("x").gt(lit(1)), true),
+            (col("x").eq(lit(0)), false),
+            (col("x").not_eq(lit(0)), true),
+            (col("x").between(lit(2), lit(5)), true),
+            (col("x").between(lit(2), lit(3)), true),
+            (col("x").between(lit(5), lit(10)), false),
+            (col("x").not_between(lit(2), lit(5)), false),
+            (col("x").not_between(lit(2), lit(3)), false),
+            (col("x").not_between(lit(5), lit(10)), true),
+            (
+                Expr::BinaryExpr(BinaryExpr {
+                    left: Box::new(col("x")),
+                    op: Operator::IsDistinctFrom,
+                    right: Box::new(lit(ScalarValue::Null)),
+                }),
+                true,
+            ),
+            (
+                Expr::BinaryExpr(BinaryExpr {
+                    left: Box::new(col("x")),
+                    op: Operator::IsDistinctFrom,
+                    right: Box::new(lit(5)),
+                }),
+                true,
+            ),
+        ];
+
+        validate_simplified_cases(&mut rewriter, simplified_cases);
+
+        let unchanged_cases = &[
+            col("x").gt(lit(2)),
+            col("x").lt_eq(lit(2)),
+            col("x").eq(lit(2)),
+            col("x").not_eq(lit(2)),
+            col("x").between(lit(3), lit(5)),
+            col("x").not_between(lit(3), lit(10)),
+        ];
+
+        validate_unchanged_cases(&mut rewriter, unchanged_cases);
+    }
+
+    #[test]
+    fn test_inequalities_non_null_unbounded() {
+        let guarantees = vec![
+            // x ∈ [2021-01-01, ∞) (not null)
+            (
+                col("x"),
+                NullableInterval::NotNull {
+                    values: Interval::new(
+                        IntervalBound::new(ScalarValue::Date32(Some(18628)), 
false),
+                        
IntervalBound::make_unbounded(DataType::Date32).unwrap(),
+                    ),
+                },
+            ),
+        ];
+        let mut rewriter = GuaranteeRewriter::new(guarantees.iter());
+
+        // (original_expr, expected_simplification)
+        let simplified_cases = &[
+            (col("x").lt(lit(ScalarValue::Date32(Some(18628)))), false),
+            (col("x").lt_eq(lit(ScalarValue::Date32(Some(17000)))), false),
+            (col("x").gt(lit(ScalarValue::Date32(Some(18627)))), true),
+            (col("x").gt_eq(lit(ScalarValue::Date32(Some(18628)))), true),
+            (col("x").eq(lit(ScalarValue::Date32(Some(17000)))), false),
+            (col("x").not_eq(lit(ScalarValue::Date32(Some(17000)))), true),
+            (
+                col("x").between(
+                    lit(ScalarValue::Date32(Some(16000))),
+                    lit(ScalarValue::Date32(Some(17000))),
+                ),
+                false,
+            ),
+            (
+                col("x").not_between(
+                    lit(ScalarValue::Date32(Some(16000))),
+                    lit(ScalarValue::Date32(Some(17000))),
+                ),
+                true,
+            ),
+            (
+                Expr::BinaryExpr(BinaryExpr {
+                    left: Box::new(col("x")),
+                    op: Operator::IsDistinctFrom,
+                    right: Box::new(lit(ScalarValue::Null)),
+                }),
+                true,
+            ),
+            (
+                Expr::BinaryExpr(BinaryExpr {
+                    left: Box::new(col("x")),
+                    op: Operator::IsDistinctFrom,
+                    right: Box::new(lit(ScalarValue::Date32(Some(17000)))),
+                }),
+                true,
+            ),
+        ];
+
+        validate_simplified_cases(&mut rewriter, simplified_cases);
+
+        let unchanged_cases = &[
+            col("x").lt(lit(ScalarValue::Date32(Some(19000)))),
+            col("x").lt_eq(lit(ScalarValue::Date32(Some(19000)))),
+            col("x").gt(lit(ScalarValue::Date32(Some(19000)))),
+            col("x").gt_eq(lit(ScalarValue::Date32(Some(19000)))),
+            col("x").eq(lit(ScalarValue::Date32(Some(19000)))),
+            col("x").not_eq(lit(ScalarValue::Date32(Some(19000)))),
+            col("x").between(
+                lit(ScalarValue::Date32(Some(18000))),
+                lit(ScalarValue::Date32(Some(19000))),
+            ),
+            col("x").not_between(
+                lit(ScalarValue::Date32(Some(18000))),
+                lit(ScalarValue::Date32(Some(19000))),
+            ),
+        ];
+
+        validate_unchanged_cases(&mut rewriter, unchanged_cases);
+    }
+
+    #[test]
+    fn test_inequalities_maybe_null() {
+        let guarantees = vec![
+            // x ∈ ("abc", "def"]? (maybe null)
+            (
+                col("x"),
+                NullableInterval::MaybeNull {
+                    values: Interval::make(Some("abc"), Some("def"), (true, 
false)),
+                },
+            ),
+        ];
+        let mut rewriter = GuaranteeRewriter::new(guarantees.iter());
+
+        // (original_expr, expected_simplification)
+        let simplified_cases = &[
+            (
+                Expr::BinaryExpr(BinaryExpr {
+                    left: Box::new(col("x")),
+                    op: Operator::IsDistinctFrom,
+                    right: Box::new(lit("z")),
+                }),
+                true,
+            ),
+            (
+                Expr::BinaryExpr(BinaryExpr {
+                    left: Box::new(col("x")),
+                    op: Operator::IsNotDistinctFrom,
+                    right: Box::new(lit("z")),
+                }),
+                false,
+            ),
+        ];
+
+        validate_simplified_cases(&mut rewriter, simplified_cases);
+
+        let unchanged_cases = &[
+            col("x").lt(lit("z")),
+            col("x").lt_eq(lit("z")),
+            col("x").gt(lit("a")),
+            col("x").gt_eq(lit("a")),
+            col("x").eq(lit("abc")),
+            col("x").not_eq(lit("a")),
+            col("x").between(lit("a"), lit("z")),
+            col("x").not_between(lit("a"), lit("z")),
+            Expr::BinaryExpr(BinaryExpr {
+                left: Box::new(col("x")),
+                op: Operator::IsDistinctFrom,
+                right: Box::new(lit(ScalarValue::Null)),
+            }),
+        ];
+
+        validate_unchanged_cases(&mut rewriter, unchanged_cases);
+    }
+
+    #[test]
+    fn test_column_single_value() {
+        let scalars = [
+            ScalarValue::Null,
+            ScalarValue::Int32(Some(1)),
+            ScalarValue::Boolean(Some(true)),
+            ScalarValue::Boolean(None),
+            ScalarValue::Utf8(Some("abc".to_string())),
+            ScalarValue::LargeUtf8(Some("def".to_string())),
+            ScalarValue::Date32(Some(18628)),
+            ScalarValue::Date32(None),
+            ScalarValue::Decimal128(Some(1000), 19, 2),
+        ];
+
+        for scalar in scalars {
+            let guarantees = vec![(col("x"), 
NullableInterval::from(scalar.clone()))];
+            let mut rewriter = GuaranteeRewriter::new(guarantees.iter());
+
+            let output = col("x").rewrite(&mut rewriter).unwrap();
+            assert_eq!(output, Expr::Literal(scalar.clone()));
+        }
+    }
+
+    #[test]
+    fn test_in_list() {
+        let guarantees = vec![
+            // x ∈ [1, 10) (not null)
+            (
+                col("x"),
+                NullableInterval::NotNull {
+                    values: Interval::make(Some(1_i32), Some(10_i32), (false, 
true)),
+                },
+            ),
+        ];
+        let mut rewriter = GuaranteeRewriter::new(guarantees.iter());
+
+        // These cases should be simplified so the list doesn't contain any
+        // values the guarantee says are outside the range.
+        // (column_name, starting_list, negated, expected_list)
+        let cases = &[
+            // x IN (9, 11) => x IN (9)
+            ("x", vec![9, 11], false, vec![9]),
+            // x IN (10, 2) => x IN (2)
+            ("x", vec![10, 2], false, vec![2]),
+            // x NOT IN (9, 11) => x NOT IN (9)
+            ("x", vec![9, 11], true, vec![9]),
+            // x NOT IN (0, 22) => x NOT IN ()
+            ("x", vec![0, 22], true, vec![]),
+        ];
+
+        for (column_name, starting_list, negated, expected_list) in cases {
+            let expr = col(*column_name).in_list(
+                starting_list
+                    .iter()
+                    .map(|v| lit(ScalarValue::Int32(Some(*v))))
+                    .collect(),
+                *negated,
+            );
+            let output = expr.clone().rewrite(&mut rewriter).unwrap();
+            let expected_list = expected_list
+                .iter()
+                .map(|v| lit(ScalarValue::Int32(Some(*v))))
+                .collect();
+            assert_eq!(
+                output,
+                Expr::InList(InList {
+                    expr: Box::new(col(*column_name)),
+                    list: expected_list,
+                    negated: *negated,
+                })
+            );
+        }
+    }
+}
diff --git a/datafusion/optimizer/src/simplify_expressions/mod.rs 
b/datafusion/optimizer/src/simplify_expressions/mod.rs
index dfa0fe7043..2cf6ed166c 100644
--- a/datafusion/optimizer/src/simplify_expressions/mod.rs
+++ b/datafusion/optimizer/src/simplify_expressions/mod.rs
@@ -17,6 +17,7 @@
 
 pub mod context;
 pub mod expr_simplifier;
+mod guarantees;
 mod or_in_list_simplifier;
 mod regex;
 pub mod simplify_exprs;
diff --git a/datafusion/physical-expr/src/intervals/interval_aritmetic.rs 
b/datafusion/physical-expr/src/intervals/interval_aritmetic.rs
index 3f72ef588c..5501c8cae0 100644
--- a/datafusion/physical-expr/src/intervals/interval_aritmetic.rs
+++ b/datafusion/physical-expr/src/intervals/interval_aritmetic.rs
@@ -396,6 +396,22 @@ impl Interval {
         }
     }
 
+    /// Compute the logical negation of this (boolean) interval.
+    pub(crate) fn not(&self) -> Result<Self> {
+        if !matches!(self.get_datatype()?, DataType::Boolean) {
+            return internal_err!(
+                "Cannot apply logical negation to non-boolean interval"
+            );
+        }
+        if self == &Interval::CERTAINLY_TRUE {
+            Ok(Interval::CERTAINLY_FALSE)
+        } else if self == &Interval::CERTAINLY_FALSE {
+            Ok(Interval::CERTAINLY_TRUE)
+        } else {
+            Ok(Interval::UNCERTAIN)
+        }
+    }
+
     /// Compute the intersection of the interval with the given interval.
     /// If the intersection is empty, return None.
     pub(crate) fn intersect<T: Borrow<Interval>>(
@@ -426,6 +442,23 @@ impl Interval {
         Ok(non_empty.then_some(Interval::new(lower, upper)))
     }
 
+    /// Decide if this interval certainly contains, possibly contains,
+    /// or can't contain `other` by returning [true, true],
+    /// [false, true] or [false, false] respectively.
+    pub fn contains<T: Borrow<Self>>(&self, other: T) -> Result<Self> {
+        match self.intersect(other.borrow())? {
+            Some(intersection) => {
+                // Need to compare with same bounds close-ness.
+                if intersection.close_bounds() == 
other.borrow().clone().close_bounds() {
+                    Ok(Interval::CERTAINLY_TRUE)
+                } else {
+                    Ok(Interval::UNCERTAIN)
+                }
+            }
+            None => Ok(Interval::CERTAINLY_FALSE),
+        }
+    }
+
     /// Add the given interval (`other`) to this interval. Say we have
     /// intervals [a1, b1] and [a2, b2], then their sum is [a1 + a2, b1 + b2].
     /// Note that this represents all possible values the sum can take if
@@ -633,6 +666,7 @@ pub fn cardinality_ratio(
 pub fn apply_operator(op: &Operator, lhs: &Interval, rhs: &Interval) -> 
Result<Interval> {
     match *op {
         Operator::Eq => Ok(lhs.equal(rhs)),
+        Operator::NotEq => Ok(lhs.equal(rhs).not()?),
         Operator::Gt => Ok(lhs.gt(rhs)),
         Operator::GtEq => Ok(lhs.gt_eq(rhs)),
         Operator::Lt => Ok(lhs.lt(rhs)),
@@ -667,6 +701,283 @@ fn calculate_cardinality_based_on_bounds(
     }
 }
 
+/// An [Interval] that also tracks whether the value is, may be, or is never null.
+///
+/// This represents values that may be in a particular range or be null.
+///
+/// # Examples
+///
+/// ```
+/// use arrow::datatypes::DataType;
+/// use datafusion_physical_expr::intervals::{Interval, NullableInterval};
+/// use datafusion_common::ScalarValue;
+///
+/// // [1, 2) U {NULL}
+/// NullableInterval::MaybeNull {
+///    values: Interval::make(Some(1), Some(2), (false, true)),
+/// };
+///
+/// // (0, ∞)
+/// NullableInterval::NotNull {
+///   values: Interval::make(Some(0), None, (true, true)),
+/// };
+///
+/// // {NULL}
+/// NullableInterval::Null { datatype: DataType::Int32 };
+///
+/// // {4}
+/// NullableInterval::from(ScalarValue::Int32(Some(4)));
+/// ```
+#[derive(Debug, Clone, PartialEq, Eq)]
+pub enum NullableInterval {
+    /// The value is always null in this interval
+    ///
+    /// This is typed so it can be used in physical expressions, which don't do
+    /// type coercion.
+    Null { datatype: DataType },
+    /// The value may or may not be null in this interval. If it is non null 
its value is within
+    /// the specified values interval
+    MaybeNull { values: Interval },
+    /// The value is definitely not null in this interval and is within values
+    NotNull { values: Interval },
+}
+
+impl Default for NullableInterval {
+    fn default() -> Self {
+        NullableInterval::MaybeNull {
+            values: Interval::default(),
+        }
+    }
+}
+
+impl Display for NullableInterval {
+    fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result {
+        match self {
+            Self::Null { .. } => write!(f, "NullableInterval: {{NULL}}"),
+            Self::MaybeNull { values } => {
+                write!(f, "NullableInterval: {} U {{NULL}}", values)
+            }
+            Self::NotNull { values } => write!(f, "NullableInterval: {}", 
values),
+        }
+    }
+}
+
+impl From<ScalarValue> for NullableInterval {
+    /// Create an interval that represents a single value.
+    fn from(value: ScalarValue) -> Self {
+        if value.is_null() {
+            Self::Null {
+                datatype: value.data_type(),
+            }
+        } else {
+            Self::NotNull {
+                values: Interval::new(
+                    IntervalBound::new(value.clone(), false),
+                    IntervalBound::new(value, false),
+                ),
+            }
+        }
+    }
+}
+
+impl NullableInterval {
+    /// Get the values interval, or None if this interval is definitely null.
+    pub fn values(&self) -> Option<&Interval> {
+        match self {
+            Self::Null { .. } => None,
+            Self::MaybeNull { values } | Self::NotNull { values } => 
Some(values),
+        }
+    }
+
+    /// Get the data type
+    pub fn get_datatype(&self) -> Result<DataType> {
+        match self {
+            Self::Null { datatype } => Ok(datatype.clone()),
+            Self::MaybeNull { values } | Self::NotNull { values } => {
+                values.get_datatype()
+            }
+        }
+    }
+
+    /// Return true if the value is definitely true (and not null).
+    pub fn is_certainly_true(&self) -> bool {
+        match self {
+            Self::Null { .. } | Self::MaybeNull { .. } => false,
+            Self::NotNull { values } => values == &Interval::CERTAINLY_TRUE,
+        }
+    }
+
+    /// Return true if the value is definitely false (and not null).
+    pub fn is_certainly_false(&self) -> bool {
+        match self {
+            Self::Null { .. } => false,
+            Self::MaybeNull { .. } => false,
+            Self::NotNull { values } => values == &Interval::CERTAINLY_FALSE,
+        }
+    }
+
+    /// Perform logical negation on a boolean nullable interval.
+    fn not(&self) -> Result<Self> {
+        match self {
+            Self::Null { datatype } => Ok(Self::Null {
+                datatype: datatype.clone(),
+            }),
+            Self::MaybeNull { values } => Ok(Self::MaybeNull {
+                values: values.not()?,
+            }),
+            Self::NotNull { values } => Ok(Self::NotNull {
+                values: values.not()?,
+            }),
+        }
+    }
+
+    /// Apply the given operator to this interval and the given interval.
+    ///
+    /// # Examples
+    ///
+    /// ```
+    /// use datafusion_common::ScalarValue;
+    /// use datafusion_expr::Operator;
+    /// use datafusion_physical_expr::intervals::{Interval, NullableInterval};
+    ///
+    /// // 4 > 3 -> true
+    /// let lhs = NullableInterval::from(ScalarValue::Int32(Some(4)));
+    /// let rhs = NullableInterval::from(ScalarValue::Int32(Some(3)));
+    /// let result = lhs.apply_operator(&Operator::Gt, &rhs).unwrap();
+    /// assert_eq!(result, 
NullableInterval::from(ScalarValue::Boolean(Some(true))));
+    ///
+    /// // [1, 3) > NULL -> NULL
+    /// let lhs = NullableInterval::NotNull {
+    ///     values: Interval::make(Some(1), Some(3), (false, true)),
+    /// };
+    /// let rhs = NullableInterval::from(ScalarValue::Int32(None));
+    /// let result = lhs.apply_operator(&Operator::Gt, &rhs).unwrap();
+    /// assert_eq!(result.single_value(), Some(ScalarValue::Boolean(None)));
+    ///
+    /// // [1, 3] > [2, 4] -> [false, true]
+    /// let lhs = NullableInterval::NotNull {
+    ///     values: Interval::make(Some(1), Some(3), (false, false)),
+    /// };
+    /// let rhs = NullableInterval::NotNull {
+    ///    values: Interval::make(Some(2), Some(4), (false, false)),
+    /// };
+    /// let result = lhs.apply_operator(&Operator::Gt, &rhs).unwrap();
+    /// // Both inputs are valid (non-null), so result must be non-null
+    /// assert_eq!(result, NullableInterval::NotNull {
+    ///    // Uncertain whether inequality is true or false
+    ///    values: Interval::UNCERTAIN,
+    /// });
+    ///
+    /// ```
+    pub fn apply_operator(&self, op: &Operator, rhs: &Self) -> Result<Self> {
+        match op {
+            Operator::IsDistinctFrom => {
+                let values = match (self, rhs) {
+                    // NULL is distinct from NULL -> False
+                    (Self::Null { .. }, Self::Null { .. }) => 
Interval::CERTAINLY_FALSE,
+                    // x is distinct from y -> x != y,
+                    // if at least one of them is never null.
+                    (Self::NotNull { .. }, _) | (_, Self::NotNull { .. }) => {
+                        let lhs_values = self.values();
+                        let rhs_values = rhs.values();
+                        match (lhs_values, rhs_values) {
+                            (Some(lhs_values), Some(rhs_values)) => {
+                                lhs_values.equal(rhs_values).not()?
+                            }
+                            (Some(_), None) | (None, Some(_)) => 
Interval::CERTAINLY_TRUE,
+                            (None, None) => unreachable!("Null case handled 
above"),
+                        }
+                    }
+                    _ => Interval::UNCERTAIN,
+                };
+                // IsDistinctFrom never returns null.
+                Ok(Self::NotNull { values })
+            }
+            Operator::IsNotDistinctFrom => self
+                .apply_operator(&Operator::IsDistinctFrom, rhs)
+                .map(|i| i.not())?,
+            _ => {
+                if let (Some(left_values), Some(right_values)) =
+                    (self.values(), rhs.values())
+                {
+                    let values = apply_operator(op, left_values, 
right_values)?;
+                    match (self, rhs) {
+                        (Self::NotNull { .. }, Self::NotNull { .. }) => {
+                            Ok(Self::NotNull { values })
+                        }
+                        _ => Ok(Self::MaybeNull { values }),
+                    }
+                } else if op.is_comparison_operator() {
+                    Ok(Self::Null {
+                        datatype: DataType::Boolean,
+                    })
+                } else {
+                    Ok(Self::Null {
+                        datatype: self.get_datatype()?,
+                    })
+                }
+            }
+        }
+    }
+
+    /// Determine if this interval contains the given interval. Returns a 
boolean
+    /// interval that is [true, true] if this interval is a superset of the
+    /// given interval, [false, false] if this interval is disjoint from the
+    /// given interval, and [false, true] otherwise.
+    pub fn contains<T: Borrow<Self>>(&self, other: T) -> Result<Self> {
+        let rhs = other.borrow();
+        if let (Some(left_values), Some(right_values)) = (self.values(), 
rhs.values()) {
+            let values = left_values.contains(right_values)?;
+            match (self, rhs) {
+                (Self::NotNull { .. }, Self::NotNull { .. }) => {
+                    Ok(Self::NotNull { values })
+                }
+                _ => Ok(Self::MaybeNull { values }),
+            }
+        } else {
+            Ok(Self::Null {
+                datatype: DataType::Boolean,
+            })
+        }
+    }
+
+    /// If the interval has collapsed to a single value, return that value.
+    ///
+    /// Otherwise returns None.
+    ///
+    /// # Examples
+    ///
+    /// ```
+    /// use datafusion_common::ScalarValue;
+    /// use datafusion_physical_expr::intervals::{Interval, NullableInterval};
+    ///
+    /// let interval = NullableInterval::from(ScalarValue::Int32(Some(4)));
+    /// assert_eq!(interval.single_value(), Some(ScalarValue::Int32(Some(4))));
+    ///
+    /// let interval = NullableInterval::from(ScalarValue::Int32(None));
+    /// assert_eq!(interval.single_value(), Some(ScalarValue::Int32(None)));
+    ///
+    /// let interval = NullableInterval::MaybeNull {
+    ///     values: Interval::make(Some(1), Some(4), (false, true)),
+    /// };
+    /// assert_eq!(interval.single_value(), None);
+    /// ```
+    pub fn single_value(&self) -> Option<ScalarValue> {
+        match self {
+            Self::Null { datatype } => {
+                
Some(ScalarValue::try_from(datatype).unwrap_or(ScalarValue::Null))
+            }
+            Self::MaybeNull { values } | Self::NotNull { values }
+                if values.lower.value == values.upper.value
+                    && !values.lower.is_unbounded() =>
+            {
+                Some(values.lower.value.clone())
+            }
+            _ => None,
+        }
+    }
+}
+
 #[cfg(test)]
 mod tests {
     use super::next_value;


Reply via email to