This is an automated email from the ASF dual-hosted git repository.

alamb pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/datafusion.git


The following commit(s) were added to refs/heads/main by this push:
     new f0e96c6701 feat: run expression simplifier in a loop until a fixedpoint or 3 cycles (#10358)
f0e96c6701 is described below

commit f0e96c670108ba0ffdebb9dd9e764bba4d2dca8c
Author: Adam Curtis <[email protected]>
AuthorDate: Tue May 7 06:43:34 2024 -0400

    feat: run expression simplifier in a loop until a fixedpoint or 3 cycles (#10358)
    
    * feat: run expression simplifier in a loop
    
    * change max_simplifier_iterations to u32
    
    * use simplify_inner to explicitly test iteration count
    
    * refactor simplify_inner loop
    
    * const evaluator should return transformed=false on literals
    
    * update tests
    
    * Update datafusion/optimizer/src/simplify_expressions/expr_simplifier.rs
    
    Co-authored-by: Andrew Lamb <[email protected]>
    
    * run shorten_in_list_simplifier once at the end of the loop
    
    * move UDF test case to core integration tests
    
    * documentation and naming updates
    
    * documentation and naming updates
    
    * remove unused import and minor doc formatting change
    
    * Update datafusion/optimizer/src/simplify_expressions/expr_simplifier.rs
    
    ---------
    
    Co-authored-by: Andrew Lamb <[email protected]>
---
 datafusion/core/tests/simplification.rs            |  31 ++++
 .../src/simplify_expressions/expr_simplifier.rs    | 175 ++++++++++++++++++---
 2 files changed, 182 insertions(+), 24 deletions(-)

diff --git a/datafusion/core/tests/simplification.rs b/datafusion/core/tests/simplification.rs
index 880c294bb7..bb41929834 100644
--- a/datafusion/core/tests/simplification.rs
+++ b/datafusion/core/tests/simplification.rs
@@ -508,6 +508,29 @@ fn test_simplify(input_expr: Expr, expected_expr: Expr) {
         "Mismatch evaluating {input_expr}\n  Expected:{expected_expr}\n  Got:{simplified_expr}"
     );
 }
+fn test_simplify_with_cycle_count(
+    input_expr: Expr,
+    expected_expr: Expr,
+    expected_count: u32,
+) {
+    let info: MyInfo = MyInfo {
+        schema: expr_test_schema(),
+        execution_props: ExecutionProps::new(),
+    };
+    let simplifier = ExprSimplifier::new(info);
+    let (simplified_expr, count) = simplifier
+        .simplify_with_cycle_count(input_expr.clone())
+        .expect("successfully evaluated");
+
+    assert_eq!(
+        simplified_expr, expected_expr,
+        "Mismatch evaluating {input_expr}\n  Expected:{expected_expr}\n  Got:{simplified_expr}"
+    );
+    assert_eq!(
+        count, expected_count,
+        "Mismatch simplifier cycle count\n Expected: {expected_count}\n Got:{count}"
+    );
+}
 
 #[test]
 fn test_simplify_log() {
@@ -658,3 +681,11 @@ fn test_simplify_concat() {
     let expected = concat(vec![col("c0"), lit("hello rust"), col("c1")]);
     test_simplify(expr, expected)
 }
+#[test]
+fn test_simplify_cycles() {
+    // cast(now() as int64) < cast(to_timestamp(0) as int64) + i64::MAX
+    let expr = cast(now(), DataType::Int64)
+        .lt(cast(to_timestamp(vec![lit(0)]), DataType::Int64) + lit(i64::MAX));
+    let expected = lit(true);
+    test_simplify_with_cycle_count(expr, expected, 3);
+}
diff --git a/datafusion/optimizer/src/simplify_expressions/expr_simplifier.rs b/datafusion/optimizer/src/simplify_expressions/expr_simplifier.rs
index 4d7a207afb..0f711d6a2c 100644
--- a/datafusion/optimizer/src/simplify_expressions/expr_simplifier.rs
+++ b/datafusion/optimizer/src/simplify_expressions/expr_simplifier.rs
@@ -92,9 +92,12 @@ pub struct ExprSimplifier<S> {
     /// Should expressions be canonicalized before simplification? Defaults to
     /// true
     canonicalize: bool,
+    /// Maximum number of simplifier cycles
+    max_simplifier_cycles: u32,
 }
 
 pub const THRESHOLD_INLINE_INLIST: usize = 3;
+pub const DEFAULT_MAX_SIMPLIFIER_CYCLES: u32 = 3;
 
 impl<S: SimplifyInfo> ExprSimplifier<S> {
     /// Create a new `ExprSimplifier` with the given `info` such as an
@@ -107,10 +110,11 @@ impl<S: SimplifyInfo> ExprSimplifier<S> {
             info,
             guarantees: vec![],
             canonicalize: true,
+            max_simplifier_cycles: DEFAULT_MAX_SIMPLIFIER_CYCLES,
         }
     }
 
-    /// Simplifies this [`Expr`]`s as much as possible, evaluating
+    /// Simplifies this [`Expr`] as much as possible, evaluating
     /// constants and applying algebraic simplifications.
     ///
     /// The types of the expression must match what operators expect,
@@ -171,7 +175,18 @@ impl<S: SimplifyInfo> ExprSimplifier<S> {
     /// let expr = simplifier.simplify(expr).unwrap();
     /// assert_eq!(expr, b_lt_2);
     /// ```
-    pub fn simplify(&self, mut expr: Expr) -> Result<Expr> {
+    pub fn simplify(&self, expr: Expr) -> Result<Expr> {
+        Ok(self.simplify_with_cycle_count(expr)?.0)
+    }
+
+    /// Like [Self::simplify], simplifies this [`Expr`] as much as possible, evaluating
+    /// constants and applying algebraic simplifications. Additionally returns a `u32`
+    /// representing the number of simplification cycles performed, which can be useful for testing
+    /// optimizations.
+    ///
+    /// See [Self::simplify] for details and usage examples.
+    ///
+    pub fn simplify_with_cycle_count(&self, mut expr: Expr) -> Result<(Expr, u32)> {
         let mut simplifier = Simplifier::new(&self.info);
         let mut const_evaluator = ConstEvaluator::try_new(self.info.execution_props())?;
         let mut shorten_in_list_simplifier = ShortenInListSimplifier::new();
@@ -181,24 +196,26 @@ impl<S: SimplifyInfo> ExprSimplifier<S> {
             expr = expr.rewrite(&mut Canonicalizer::new()).data()?
         }
 
-        // TODO iterate until no changes are made during rewrite
-        // (evaluating constants can enable new simplifications and
-        // simplifications can enable new constant evaluation)
-        // https://github.com/apache/datafusion/issues/1160
-        expr.rewrite(&mut const_evaluator)
-            .data()?
-            .rewrite(&mut simplifier)
-            .data()?
-            .rewrite(&mut guarantee_rewriter)
-            .data()?
-            // run both passes twice to try an minimize simplifications that we missed
-            .rewrite(&mut const_evaluator)
-            .data()?
-            .rewrite(&mut simplifier)
-            .data()?
-            // shorten inlist should be started after other inlist rules are applied
-            .rewrite(&mut shorten_in_list_simplifier)
-            .data()
+        // Evaluating constants can enable new simplifications and
+        // simplifications can enable new constant evaluation
+        // see `Self::with_max_cycles`
+        let mut num_cycles = 0;
+        loop {
+            let Transformed {
+                data, transformed, ..
+            } = expr
+                .rewrite(&mut const_evaluator)?
+                .transform_data(|expr| expr.rewrite(&mut simplifier))?
+                .transform_data(|expr| expr.rewrite(&mut guarantee_rewriter))?;
+            expr = data;
+            num_cycles += 1;
+            if !transformed || num_cycles >= self.max_simplifier_cycles {
+                break;
+            }
+        }
+        // shorten inlist should be started after other inlist rules are applied
+        expr = expr.rewrite(&mut shorten_in_list_simplifier).data()?;
+        Ok((expr, num_cycles))
     }
 
     /// Apply type coercion to an [`Expr`] so that it can be
@@ -323,6 +340,63 @@ impl<S: SimplifyInfo> ExprSimplifier<S> {
         self.canonicalize = canonicalize;
         self
     }
+
+    /// Specifies the maximum number of simplification cycles to run.
+    ///
+    /// The simplifier can perform multiple passes of simplification. This is
+    /// because the output of one simplification step can allow more optimizations
+    /// in another simplification step. For example, constant evaluation can allow more
+    /// expression simplifications, and expression simplifications can allow more constant
+    /// evaluations.
+    ///
+    /// This method specifies the maximum number of allowed iteration cycles before the simplifier
+    /// returns an [Expr] output. However, it does not always perform the maximum number of cycles.
+    /// The simplifier will attempt to detect when an [Expr] is unchanged by all the simplification
+    /// passes, and return early. This avoids wasting time on unnecessary [Expr] tree traversals.
+    ///
+    /// If no maximum is specified, the value of [DEFAULT_MAX_SIMPLIFIER_CYCLES] is used
+    /// instead.
+    ///
+    /// ```rust
+    /// use arrow::datatypes::{DataType, Field, Schema};
+    /// use datafusion_expr::{col, lit, Expr};
+    /// use datafusion_common::{Result, ScalarValue, ToDFSchema};
+    /// use datafusion_expr::execution_props::ExecutionProps;
+    /// use datafusion_expr::simplify::SimplifyContext;
+    /// use datafusion_optimizer::simplify_expressions::ExprSimplifier;
+    ///
+    /// let schema = Schema::new(vec![
+    ///   Field::new("a", DataType::Int64, false),
+    ///   ])
+    ///   .to_dfschema_ref().unwrap();
+    ///
+    /// // Create the simplifier
+    /// let props = ExecutionProps::new();
+    /// let context = SimplifyContext::new(&props)
+    ///    .with_schema(schema);
+    /// let simplifier = ExprSimplifier::new(context);
+    ///
+    /// // Expression: a IS NOT NULL
+    /// let expr = col("a").is_not_null();
+    ///
+    /// // When using default maximum cycles, 2 cycles will be performed.
+    /// let (simplified_expr, count) = simplifier.simplify_with_cycle_count(expr.clone()).unwrap();
+    /// assert_eq!(simplified_expr, lit(true));
+    /// // 2 cycles were executed, but only 1 was needed
+    /// assert_eq!(count, 2);
+    ///
+    /// // Only 1 simplification pass is necessary here, so we can set the maximum cycles to 1.
+    /// let (simplified_expr, count) = simplifier.with_max_cycles(1).simplify_with_cycle_count(expr.clone()).unwrap();
+    /// // Expression has been rewritten to: (c = a AND b = 1)
+    /// assert_eq!(simplified_expr, lit(true));
+    /// // Only 1 cycle was executed
+    /// assert_eq!(count, 1);
+    ///
+    /// ```
+    pub fn with_max_cycles(mut self, max_simplifier_cycles: u32) -> Self {
+        self.max_simplifier_cycles = max_simplifier_cycles;
+        self
+    }
 }
 
 /// Canonicalize any BinaryExprs that are not in canonical form
@@ -404,6 +478,8 @@ struct ConstEvaluator<'a> {
 enum ConstSimplifyResult {
     // Expr was simplifed and contains the new expression
     Simplified(ScalarValue),
+    // Expr was not simplified and original value is returned
+    NotSimplified(ScalarValue),
     // Evaluation encountered an error, contains the original expression
     SimplifyRuntimeError(DataFusionError, Expr),
 }
@@ -450,6 +526,9 @@ impl<'a> TreeNodeRewriter for ConstEvaluator<'a> {
                     ConstSimplifyResult::Simplified(s) => {
                         Ok(Transformed::yes(Expr::Literal(s)))
                     }
+                    ConstSimplifyResult::NotSimplified(s) => {
+                        Ok(Transformed::no(Expr::Literal(s)))
+                    }
                     ConstSimplifyResult::SimplifyRuntimeError(_, expr) => {
                         Ok(Transformed::yes(expr))
                     }
@@ -548,7 +627,7 @@ impl<'a> ConstEvaluator<'a> {
     /// Internal helper to evaluates an Expr
     pub(crate) fn evaluate_to_scalar(&mut self, expr: Expr) -> ConstSimplifyResult {
         if let Expr::Literal(s) = expr {
-            return ConstSimplifyResult::Simplified(s);
+            return ConstSimplifyResult::NotSimplified(s);
         }
 
         let phys_expr =
@@ -1672,15 +1751,14 @@ fn inlist_except(mut l1: InList, l2: InList) -> Result<Expr> {
 
 #[cfg(test)]
 mod tests {
+    use datafusion_common::{assert_contains, DFSchemaRef, ToDFSchema};
+    use datafusion_expr::{interval_arithmetic::Interval, *};
     use std::{
         collections::HashMap,
         ops::{BitAnd, BitOr, BitXor},
         sync::Arc,
     };
 
-    use datafusion_common::{assert_contains, DFSchemaRef, ToDFSchema};
-    use datafusion_expr::{interval_arithmetic::Interval, *};
-
     use crate::simplify_expressions::SimplifyContext;
     use crate::test::test_table_scan_with_name;
 
@@ -2868,6 +2946,19 @@ mod tests {
         try_simplify(expr).unwrap()
     }
 
+    fn try_simplify_with_cycle_count(expr: Expr) -> Result<(Expr, u32)> {
+        let schema = expr_test_schema();
+        let execution_props = ExecutionProps::new();
+        let simplifier = ExprSimplifier::new(
+            SimplifyContext::new(&execution_props).with_schema(schema),
+        );
+        simplifier.simplify_with_cycle_count(expr)
+    }
+
+    fn simplify_with_cycle_count(expr: Expr) -> (Expr, u32) {
+        try_simplify_with_cycle_count(expr).unwrap()
+    }
+
     fn simplify_with_guarantee(
         expr: Expr,
         guarantees: Vec<(Expr, NullableInterval)>,
@@ -3575,4 +3666,40 @@ mod tests {
 
         assert_eq!(simplify(expr), expected);
     }
+
+    #[test]
+    fn test_simplify_cycles() {
+        // TRUE
+        let expr = lit(true);
+        let expected = lit(true);
+        let (expr, num_iter) = simplify_with_cycle_count(expr);
+        assert_eq!(expr, expected);
+        assert_eq!(num_iter, 1);
+
+        // (true != NULL) OR (5 > 10)
+        let expr = lit(true).not_eq(lit_bool_null()).or(lit(5).gt(lit(10)));
+        let expected = lit_bool_null();
+        let (expr, num_iter) = simplify_with_cycle_count(expr);
+        assert_eq!(expr, expected);
+        assert_eq!(num_iter, 2);
+
+        // NOTE: this currently does not simplify
+        // (((c4 - 10) + 10) *100) / 100
+        let expr = (((col("c4") - lit(10)) + lit(10)) * lit(100)) / lit(100);
+        let expected = expr.clone();
+        let (expr, num_iter) = simplify_with_cycle_count(expr);
+        assert_eq!(expr, expected);
+        assert_eq!(num_iter, 1);
+
+        // ((c4<1 or c3<2) and c3_non_null<3) and false
+        let expr = col("c4")
+            .lt(lit(1))
+            .or(col("c3").lt(lit(2)))
+            .and(col("c3_non_null").lt(lit(3)))
+            .and(lit(false));
+        let expected = lit(false);
+        let (expr, num_iter) = simplify_with_cycle_count(expr);
+        assert_eq!(expr, expected);
+        assert_eq!(num_iter, 2);
+    }
 }


---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]

Reply via email to