This is an automated email from the ASF dual-hosted git repository.

wjones127 pushed a commit to branch 6171-simplify-with-guarantee
in repository https://gitbox.apache.org/repos/asf/arrow-datafusion.git

commit 2134f2f18fbf0937ba8d0cc5abd8d0919d9b4d3b
Author: Will Jones <[email protected]>
AuthorDate: Mon Sep 4 10:22:06 2023 -0700

    add support for literal expressions
---
 .../src/simplify_expressions/guarantees.rs         | 93 ++++++++++++----------
 1 file changed, 49 insertions(+), 44 deletions(-)

diff --git a/datafusion/optimizer/src/simplify_expressions/guarantees.rs b/datafusion/optimizer/src/simplify_expressions/guarantees.rs
index 3bca3345ae..4e142ef280 100644
--- a/datafusion/optimizer/src/simplify_expressions/guarantees.rs
+++ b/datafusion/optimizer/src/simplify_expressions/guarantees.rs
@@ -121,9 +121,10 @@ impl From<&ScalarValue> for Guarantee {
                 bound: value.clone(),
                 open: false,
             },
-            null_status: match value {
-                ScalarValue::Null => NullStatus::AlwaysNull,
-                _ => NullStatus::NeverNull,
+            null_status: if value.is_null() {
+                NullStatus::AlwaysNull
+            } else {
+                NullStatus::NeverNull
             },
         }
     }
@@ -318,6 +319,25 @@ impl<'a> TreeNodeRewriter for GuaranteeRewriter<'a> {
             }
 
             // Columns (if bounds are equal and closed and column is not nullable)
+            Expr::Column(_) => {
+                if let Some(guarantee) = self.guarantees.get(&expr) {
+                    if guarantee.min == guarantee.max
+                        // Case where column has a single valid value
+                        && ((!guarantee.min.open
+                            && !guarantee.min.bound.is_null()
+                            && guarantee.null_status == NullStatus::NeverNull)
+                            // Case where column is always null
+                            || (guarantee.min.bound.is_null()
+                                && guarantee.null_status == NullStatus::AlwaysNull))
+                    {
+                        Ok(lit(guarantee.min.bound.clone()))
+                    } else {
+                        Ok(expr)
+                    }
+                } else {
+                    Ok(expr)
+                }
+            }
 
             // In list
             _ => Ok(expr),
@@ -336,25 +356,21 @@ mod tests {
     fn test_null_handling() {
         // IsNull / IsNotNull can be rewritten to true / false
         let guarantees = vec![
-            (col("x"), Guarantee::new(None, None, NullStatus::AlwaysNull)),
-            (col("y"), Guarantee::new(None, None, NullStatus::NeverNull)),
+            // Note: AlwaysNull case handled by test_column_single_value test,
+            // since it's a special case of a column with a single value.
+            (col("x"), Guarantee::new(None, None, NullStatus::NeverNull)),
         ];
         let mut rewriter = GuaranteeRewriter::new(guarantees.iter());
 
-        let cases = &[
-            (col("x").is_null(), true),
-            (col("x").is_not_null(), false),
-            (col("y").is_null(), false),
-            (col("y").is_not_null(), true),
-        ];
+        // x IS NULL => guaranteed false
+        let expr = col("x").is_null();
+        let output = expr.clone().rewrite(&mut rewriter).unwrap();
+        assert_eq!(output, lit(false));
 
-        for (expr, expected_value) in cases {
-            let output = expr.clone().rewrite(&mut rewriter).unwrap();
-            assert_eq!(
-                output,
-                Expr::Literal(ScalarValue::Boolean(Some(*expected_value)))
-            );
-        }
+        // x IS NOT NULL => guaranteed true
+        let expr = col("x").is_not_null();
+        let output = expr.clone().rewrite(&mut rewriter).unwrap();
+        assert_eq!(output, lit(true));
     }
 
     #[test]
@@ -431,35 +447,24 @@ mod tests {
 
     #[test]
     fn test_column_single_value() {
-        let guarantees = vec![
-            // x = 2
-            (col("x"), Guarantee::from(&ScalarValue::Int32(Some(2)))),
-            // y is Null
-            (col("y"), Guarantee::from(&ScalarValue::Null)),
+        let scalars = [
+            ScalarValue::Null,
+            ScalarValue::Int32(Some(1)),
+            ScalarValue::Boolean(Some(true)),
+            ScalarValue::Boolean(None),
+            ScalarValue::Utf8(Some("abc".to_string())),
+            ScalarValue::LargeUtf8(Some("def".to_string())),
+            ScalarValue::Date32(Some(18628)),
+            ScalarValue::Date32(None),
+            ScalarValue::Decimal128(Some(1000), 19, 2),
         ];
-        let mut rewriter = GuaranteeRewriter::new(guarantees.iter());
 
-        // These cases should be simplified
-        let cases = &[
-            (col("x").lt_eq(lit(1)), false),
-            (col("x").gt(lit(3)), false),
-            (col("x").eq(lit(1)), false),
-            (col("x").eq(lit(2)), true),
-            (col("x").gt(lit(1)), true),
-            (col("x").lt_eq(lit(2)), true),
-            (col("x").is_not_null(), true),
-            (col("x").is_null(), false),
-            (col("y").is_null(), true),
-            (col("y").is_not_null(), false),
-            (col("y").lt_eq(lit(17000)), false),
-        ];
+        for scalar in &scalars {
+            let guarantees = vec![(col("x"), Guarantee::from(scalar))];
+            let mut rewriter = GuaranteeRewriter::new(guarantees.iter());
 
-        for (expr, expected_value) in cases {
-            let output = expr.clone().rewrite(&mut rewriter).unwrap();
-            assert_eq!(
-                output,
-                Expr::Literal(ScalarValue::Boolean(Some(*expected_value)))
-            );
+            let output = col("x").rewrite(&mut rewriter).unwrap();
+            assert_eq!(output, Expr::Literal(scalar.clone()));
         }
     }
 

Reply via email to