This is an automated email from the ASF dual-hosted git repository.
alamb pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/datafusion.git
The following commit(s) were added to refs/heads/main by this push:
new f0e96c6701 feat: run expression simplifier in a loop until a
fixedpoint or 3 cycles (#10358)
f0e96c6701 is described below
commit f0e96c670108ba0ffdebb9dd9e764bba4d2dca8c
Author: Adam Curtis <[email protected]>
AuthorDate: Tue May 7 06:43:34 2024 -0400
feat: run expression simplifier in a loop until a fixedpoint or 3 cycles
(#10358)
* feat: run expression simplifier in a loop
* change max_simplifier_iterations to u32
* use simplify_inner to explicitly test iteration count
* refactor simplify_inner loop
* const evaluator should return transformed=false on literals
* update tests
* Update datafusion/optimizer/src/simplify_expressions/expr_simplifier.rs
Co-authored-by: Andrew Lamb <[email protected]>
* run shorten_in_list_simplifier once at the end of the loop
* move UDF test case to core integration tests
* documentation and naming updates
* documentation and naming updates
* remove unused import and minor doc formatting change
* Update datafusion/optimizer/src/simplify_expressions/expr_simplifier.rs
---------
Co-authored-by: Andrew Lamb <[email protected]>
---
datafusion/core/tests/simplification.rs | 31 ++++
.../src/simplify_expressions/expr_simplifier.rs | 175 ++++++++++++++++++---
2 files changed, 182 insertions(+), 24 deletions(-)
diff --git a/datafusion/core/tests/simplification.rs
b/datafusion/core/tests/simplification.rs
index 880c294bb7..bb41929834 100644
--- a/datafusion/core/tests/simplification.rs
+++ b/datafusion/core/tests/simplification.rs
@@ -508,6 +508,29 @@ fn test_simplify(input_expr: Expr, expected_expr: Expr) {
"Mismatch evaluating {input_expr}\n Expected:{expected_expr}\n
Got:{simplified_expr}"
);
}
+fn test_simplify_with_cycle_count(
+ input_expr: Expr,
+ expected_expr: Expr,
+ expected_count: u32,
+) {
+ let info: MyInfo = MyInfo {
+ schema: expr_test_schema(),
+ execution_props: ExecutionProps::new(),
+ };
+ let simplifier = ExprSimplifier::new(info);
+ let (simplified_expr, count) = simplifier
+ .simplify_with_cycle_count(input_expr.clone())
+ .expect("successfully evaluated");
+ // Check both the simplified expression and the number of cycles reported
+ assert_eq!(
+ simplified_expr, expected_expr,
+ "Mismatch evaluating {input_expr}\n Expected:{expected_expr}\n
Got:{simplified_expr}"
+ );
+ assert_eq!(
+ count, expected_count,
+ "Mismatch simplifier cycle count\n Expected: {expected_count}\n
Got:{count}"
+ );
+}
#[test]
fn test_simplify_log() {
@@ -658,3 +681,11 @@ fn test_simplify_concat() {
let expected = concat(vec![col("c0"), lit("hello rust"), col("c1")]);
test_simplify(expr, expected)
}
+#[test]
+fn test_simplify_cycles() {
+ // cast(now() as int64) < cast(to_timestamp(0) as int64) + i64::MAX simplifies to TRUE; 3 simplifier cycles are expected
+ let expr = cast(now(), DataType::Int64)
+ .lt(cast(to_timestamp(vec![lit(0)]), DataType::Int64) + lit(i64::MAX));
+ let expected = lit(true);
+ test_simplify_with_cycle_count(expr, expected, 3);
+}
diff --git a/datafusion/optimizer/src/simplify_expressions/expr_simplifier.rs
b/datafusion/optimizer/src/simplify_expressions/expr_simplifier.rs
index 4d7a207afb..0f711d6a2c 100644
--- a/datafusion/optimizer/src/simplify_expressions/expr_simplifier.rs
+++ b/datafusion/optimizer/src/simplify_expressions/expr_simplifier.rs
@@ -92,9 +92,12 @@ pub struct ExprSimplifier<S> {
/// Should expressions be canonicalized before simplification? Defaults to
/// true
canonicalize: bool,
+ /// Maximum number of simplifier cycles
+ max_simplifier_cycles: u32,
}
pub const THRESHOLD_INLINE_INLIST: usize = 3;
+pub const DEFAULT_MAX_SIMPLIFIER_CYCLES: u32 = 3; // cycle limit applied by ExprSimplifier::new
impl<S: SimplifyInfo> ExprSimplifier<S> {
/// Create a new `ExprSimplifier` with the given `info` such as an
@@ -107,10 +110,11 @@ impl<S: SimplifyInfo> ExprSimplifier<S> {
info,
guarantees: vec![],
canonicalize: true,
+ max_simplifier_cycles: DEFAULT_MAX_SIMPLIFIER_CYCLES,
}
}
- /// Simplifies this [`Expr`]`s as much as possible, evaluating
+ /// Simplifies this [`Expr`] as much as possible, evaluating
/// constants and applying algebraic simplifications.
///
/// The types of the expression must match what operators expect,
@@ -171,7 +175,18 @@ impl<S: SimplifyInfo> ExprSimplifier<S> {
/// let expr = simplifier.simplify(expr).unwrap();
/// assert_eq!(expr, b_lt_2);
/// ```
- pub fn simplify(&self, mut expr: Expr) -> Result<Expr> {
+ pub fn simplify(&self, expr: Expr) -> Result<Expr> {
+ Ok(self.simplify_with_cycle_count(expr)?.0) // keep the expression, discard the cycle count
+ }
+
+ /// Like [Self::simplify], simplifies this [`Expr`] as much as possible,
evaluating
+ /// constants and applying algebraic simplifications. Additionally returns
a `u32`
+ /// representing the number of simplification cycles performed, which can
be useful for testing
+ /// optimizations.
+ ///
+ /// See [Self::simplify] for details and usage examples.
+ ///
+ pub fn simplify_with_cycle_count(&self, mut expr: Expr) -> Result<(Expr,
u32)> {
let mut simplifier = Simplifier::new(&self.info);
let mut const_evaluator =
ConstEvaluator::try_new(self.info.execution_props())?;
let mut shorten_in_list_simplifier = ShortenInListSimplifier::new();
@@ -181,24 +196,26 @@ impl<S: SimplifyInfo> ExprSimplifier<S> {
expr = expr.rewrite(&mut Canonicalizer::new()).data()?
}
- // TODO iterate until no changes are made during rewrite
- // (evaluating constants can enable new simplifications and
- // simplifications can enable new constant evaluation)
- // https://github.com/apache/datafusion/issues/1160
- expr.rewrite(&mut const_evaluator)
- .data()?
- .rewrite(&mut simplifier)
- .data()?
- .rewrite(&mut guarantee_rewriter)
- .data()?
- // run both passes twice to try an minimize simplifications that
we missed
- .rewrite(&mut const_evaluator)
- .data()?
- .rewrite(&mut simplifier)
- .data()?
- // shorten inlist should be started after other inlist rules are
applied
- .rewrite(&mut shorten_in_list_simplifier)
- .data()
+ // Evaluating constants can enable new simplifications and
+ // simplifications can enable new constant evaluation
+ // see `Self::with_max_cycles`
+ let mut num_cycles = 0;
+ loop {
+ let Transformed {
+ data, transformed, ..
+ } = expr
+ .rewrite(&mut const_evaluator)?
+ .transform_data(|expr| expr.rewrite(&mut simplifier))?
+ .transform_data(|expr| expr.rewrite(&mut guarantee_rewriter))?;
+ expr = data;
+ num_cycles += 1;
+ if !transformed || num_cycles >= self.max_simplifier_cycles {
+ break;
+ }
+ }
+ // shorten inlist should be started after other inlist rules are
applied
+ expr = expr.rewrite(&mut shorten_in_list_simplifier).data()?;
+ Ok((expr, num_cycles))
}
/// Apply type coercion to an [`Expr`] so that it can be
@@ -323,6 +340,63 @@ impl<S: SimplifyInfo> ExprSimplifier<S> {
self.canonicalize = canonicalize;
self
}
+
+ /// Specifies the maximum number of simplification cycles to run.
+ ///
+ /// The simplifier can perform multiple passes of simplification. This is
+ /// because the output of one simplification step can allow more
optimizations
+ /// in another simplification step. For example, constant evaluation can
allow more
+ /// expression simplifications, and expression simplifications can allow
more constant
+ /// evaluations.
+ ///
+ /// This method specifies the maximum number of allowed iteration cycles
before the simplifier
+ /// returns an [Expr] output. However, it does not always perform the
maximum number of cycles.
+ /// The simplifier will attempt to detect when an [Expr] is unchanged by
all the simplification
+ /// passes, and return early. This avoids wasting time on unnecessary
[Expr] tree traversals.
+ ///
+ /// If no maximum is specified, the value of
[DEFAULT_MAX_SIMPLIFIER_CYCLES] is used
+ /// instead. At least one cycle is always run, even when the maximum is 0.
+ ///
+ /// ```rust
+ /// use arrow::datatypes::{DataType, Field, Schema};
+ /// use datafusion_expr::{col, lit, Expr};
+ /// use datafusion_common::{Result, ScalarValue, ToDFSchema};
+ /// use datafusion_expr::execution_props::ExecutionProps;
+ /// use datafusion_expr::simplify::SimplifyContext;
+ /// use datafusion_optimizer::simplify_expressions::ExprSimplifier;
+ ///
+ /// let schema = Schema::new(vec![
+ /// Field::new("a", DataType::Int64, false),
+ /// ])
+ /// .to_dfschema_ref().unwrap();
+ ///
+ /// // Create the simplifier
+ /// let props = ExecutionProps::new();
+ /// let context = SimplifyContext::new(&props)
+ /// .with_schema(schema);
+ /// let simplifier = ExprSimplifier::new(context);
+ ///
+ /// // Expression: a IS NOT NULL
+ /// let expr = col("a").is_not_null();
+ ///
+ /// // When using default maximum cycles, 2 cycles will be performed.
+ /// let (simplified_expr, count) =
simplifier.simplify_with_cycle_count(expr.clone()).unwrap();
+ /// assert_eq!(simplified_expr, lit(true));
+ /// // 2 cycles were executed, but only 1 was needed: the second found no further changes
+ /// assert_eq!(count, 2);
+ ///
+ /// // Only 1 simplification pass is necessary here, so we can set the
maximum cycles to 1.
+ /// let (simplified_expr, count) =
simplifier.with_max_cycles(1).simplify_with_cycle_count(expr.clone()).unwrap();
+ /// // The expression is still fully simplified to TRUE in a single cycle
+ /// assert_eq!(simplified_expr, lit(true));
+ /// // Only 1 cycle was executed
+ /// assert_eq!(count, 1);
+ ///
+ /// ```
+ pub fn with_max_cycles(mut self, max_simplifier_cycles: u32) -> Self {
+ self.max_simplifier_cycles = max_simplifier_cycles;
+ self
+ }
}
/// Canonicalize any BinaryExprs that are not in canonical form
@@ -404,6 +478,8 @@ struct ConstEvaluator<'a> {
enum ConstSimplifyResult {
// Expr was simplifed and contains the new expression
Simplified(ScalarValue),
+ // Expr was not simplified (e.g. it was already a literal); its original value is returned
+ NotSimplified(ScalarValue),
// Evaluation encountered an error, contains the original expression
SimplifyRuntimeError(DataFusionError, Expr),
}
@@ -450,6 +526,9 @@ impl<'a> TreeNodeRewriter for ConstEvaluator<'a> {
ConstSimplifyResult::Simplified(s) => {
Ok(Transformed::yes(Expr::Literal(s)))
}
+ ConstSimplifyResult::NotSimplified(s) => {
+ Ok(Transformed::no(Expr::Literal(s)))
+ }
ConstSimplifyResult::SimplifyRuntimeError(_, expr) => {
Ok(Transformed::yes(expr))
}
@@ -548,7 +627,7 @@ impl<'a> ConstEvaluator<'a> {
/// Internal helper to evaluates an Expr
pub(crate) fn evaluate_to_scalar(&mut self, expr: Expr) ->
ConstSimplifyResult {
if let Expr::Literal(s) = expr {
- return ConstSimplifyResult::Simplified(s);
+ return ConstSimplifyResult::NotSimplified(s);
}
let phys_expr =
@@ -1672,15 +1751,14 @@ fn inlist_except(mut l1: InList, l2: InList) ->
Result<Expr> {
#[cfg(test)]
mod tests {
+ use datafusion_common::{assert_contains, DFSchemaRef, ToDFSchema};
+ use datafusion_expr::{interval_arithmetic::Interval, *};
use std::{
collections::HashMap,
ops::{BitAnd, BitOr, BitXor},
sync::Arc,
};
- use datafusion_common::{assert_contains, DFSchemaRef, ToDFSchema};
- use datafusion_expr::{interval_arithmetic::Interval, *};
-
use crate::simplify_expressions::SimplifyContext;
use crate::test::test_table_scan_with_name;
@@ -2868,6 +2946,19 @@ mod tests {
try_simplify(expr).unwrap()
}
+ fn try_simplify_with_cycle_count(expr: Expr) -> Result<(Expr, u32)> {
+ let schema = expr_test_schema();
+ let execution_props = ExecutionProps::new();
+ let simplifier = ExprSimplifier::new(
+ SimplifyContext::new(&execution_props).with_schema(schema),
+ );
+ simplifier.simplify_with_cycle_count(expr)
+ }
+ /// Like `try_simplify_with_cycle_count`, but panics if simplification fails
+ fn simplify_with_cycle_count(expr: Expr) -> (Expr, u32) {
+ try_simplify_with_cycle_count(expr).unwrap()
+ }
+
fn simplify_with_guarantee(
expr: Expr,
guarantees: Vec<(Expr, NullableInterval)>,
@@ -3575,4 +3666,40 @@ mod tests {
assert_eq!(simplify(expr), expected);
}
+
+ #[test]
+ fn test_simplify_cycles() {
+ // TRUE: already a literal, so the first cycle reports no change and the loop stops
+ let expr = lit(true);
+ let expected = lit(true);
+ let (expr, num_iter) = simplify_with_cycle_count(expr);
+ assert_eq!(expr, expected);
+ assert_eq!(num_iter, 1);
+
+ // (true != NULL) OR (5 > 10): reaches NULL in cycle 1; cycle 2 makes no further change
+ let expr = lit(true).not_eq(lit_bool_null()).or(lit(5).gt(lit(10)));
+ let expected = lit_bool_null();
+ let (expr, num_iter) = simplify_with_cycle_count(expr);
+ assert_eq!(expr, expected);
+ assert_eq!(num_iter, 2);
+
+ // NOTE: this currently does not simplify, so only 1 cycle runs
+ // (((c4 - 10) + 10) *100) / 100
+ let expr = (((col("c4") - lit(10)) + lit(10)) * lit(100)) / lit(100);
+ let expected = expr.clone();
+ let (expr, num_iter) = simplify_with_cycle_count(expr);
+ assert_eq!(expr, expected);
+ assert_eq!(num_iter, 1);
+
+ // ((c4<1 or c3<2) and c3_non_null<3) and false: reaches FALSE in cycle 1; cycle 2 makes no further change
+ let expr = col("c4")
+ .lt(lit(1))
+ .or(col("c3").lt(lit(2)))
+ .and(col("c3_non_null").lt(lit(3)))
+ .and(lit(false));
+ let expected = lit(false);
+ let (expr, num_iter) = simplify_with_cycle_count(expr);
+ assert_eq!(expr, expected);
+ assert_eq!(num_iter, 2);
+ }
}
---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]