alamb commented on code in PR #19722:
URL: https://github.com/apache/datafusion/pull/19722#discussion_r2706112406


##########
datafusion/optimizer/src/simplify_expressions/expr_simplifier.rs:
##########
@@ -1969,12 +1972,101 @@ impl TreeNodeRewriter for Simplifier<'_> {
                 }))
             }
 
+            // =======================================
+            // preimage_in_comparison
+            // =======================================
+            //
+            // For case:
+            // date_part('YEAR', expr) op literal
+            //
+            // Background:
+            // Datasources such as Parquet can prune partitions using simple 
predicates,
+            // but they cannot do so for complex expressions.
+            // For a complex predicate like `date_part('YEAR', c1) < 2000`, 
pruning is not possible.
+            // After rewriting it to `c1 < 2000-01-01`, pruning becomes 
feasible.
+            // Rewrites use inclusive lower and exclusive upper bounds when
+            // translating an equality into a range.
+            // NOTE: we only consider immutable UDFs with literal RHS values 
and
+            // UDFs that provide both `preimage` and `column_expr`.
+            Expr::BinaryExpr(BinaryExpr { left, op, right }) => {
+                use datafusion_expr::Operator::*;
+                let is_preimage_op = matches!(
+                    op,
+                    Eq | NotEq
+                        | Lt
+                        | LtEq
+                        | Gt
+                        | GtEq
+                        | IsDistinctFrom
+                        | IsNotDistinctFrom
+                );
+                if !is_preimage_op {
+                    return Ok(Transformed::no(Expr::BinaryExpr(BinaryExpr {
+                        left,
+                        op,
+                        right,
+                    })));
+                }
+
+                if let (Some(interval), Some(col_expr)) =
+                    get_preimage(left.as_ref(), right.as_ref(), info)?
+                {
+                    rewrite_with_preimage(info, interval, op, 
Box::new(col_expr))?
+                } else if let Some(swapped) = op.swap() {
+                    if let (Some(interval), Some(col_expr)) =
+                        get_preimage(right.as_ref(), left.as_ref(), info)?
+                    {
+                        rewrite_with_preimage(
+                            info,
+                            interval,
+                            swapped,
+                            Box::new(col_expr),
+                        )?
+                    } else {
+                        Transformed::no(Expr::BinaryExpr(BinaryExpr { left, 
op, right }))
+                    }
+                } else {
+                    Transformed::no(Expr::BinaryExpr(BinaryExpr { left, op, 
right }))
+                }
+            }
+
             // no additional rewrites possible
             expr => Transformed::no(expr),
         })
     }
 }
 
+fn get_preimage(
+    left_expr: &Expr,
+    right_expr: &Expr,
+    info: &SimplifyContext,
+) -> Result<(Option<Interval>, Option<Expr>)> {
+    let Expr::ScalarFunction(ScalarFunction { func, args }) = left_expr else {
+        return Ok((None, None));
+    };
+    if !is_literal_or_literal_cast(right_expr) {

Review Comment:
   I wonder if there is a reason to limit this to literals? It seems like the 
call to `preimage` could handle this (and basically return if it wasn't a 
literal) 



##########
datafusion/optimizer/src/simplify_expressions/udf_preimage.rs:
##########
@@ -0,0 +1,270 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements.  See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership.  The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License.  You may obtain a copy of the License at
+//
+//   http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied.  See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+use datafusion_common::{Result, internal_err, tree_node::Transformed};
+use datafusion_expr::{
+    Expr, Operator, and, binary_expr, lit, or, simplify::SimplifyContext,
+};
+use datafusion_expr_common::interval_arithmetic::Interval;
+
+/// Rewrites a binary expression using its "preimage"
+///
+/// Specifically it rewrites expressions of the form `<expr> OP x` (e.g. 
`<expr> =
+/// x`) where `<expr>` is known to have a pre-image (aka the entire single
+/// range for which it is valid)
+///
+/// This rewrite is described in the [ClickHouse Paper] and is particularly
+/// useful for simplifying expressions `date_part` or equivalent functions. The
+/// idea is that if you have an expression like `date_part(YEAR, k) = 2024` 
and you
+/// can find a [preimage] for `date_part(YEAR, k)`, which is the range of dates
+/// covering the entire year of 2024. Thus, you can rewrite the expression to
+/// `k >= '2024-01-01' AND k < '2025-01-01'`, which uses an inclusive lower 
bound
+/// and exclusive upper bound and is often more optimizable.
+///
+/// [ClickHouse Paper]:  https://www.vldb.org/pvldb/vol17/p3731-schulze.pdf
+/// [preimage]: https://en.wikipedia.org/wiki/Image_(mathematics)#Inverse_image
+///
+pub(super) fn rewrite_with_preimage(
+    _info: &SimplifyContext,
+    preimage_interval: Interval,
+    op: Operator,
+    expr: Box<Expr>,
+) -> Result<Transformed<Expr>> {
+    let (lower, upper) = preimage_interval.into_bounds();
+    let (lower, upper) = (lit(lower), lit(upper));
+
+    let rewritten_expr = match op {
+        // <expr> < x   ==>  <expr> < lower
+        // <expr> >= x  ==>  <expr> >= lower
+        Operator::Lt | Operator::GtEq => binary_expr(*expr, op, lower),
+        // <expr> > x ==> <expr> >= upper
+        Operator::Gt => binary_expr(*expr, Operator::GtEq, upper),
+        // <expr> <= x ==> <expr> < upper
+        Operator::LtEq => binary_expr(*expr, Operator::Lt, upper),
+        // <expr> = x ==> (<expr> >= lower) and (<expr> < upper)
+        //
+        // <expr> is not distinct from x ==> (<expr> is NULL and x is NULL) or 
((<expr> >= lower) and (<expr> < upper))
+        // but since x is always not NULL => (<expr> >= lower) and (<expr> < 
upper)
+        Operator::Eq | Operator::IsNotDistinctFrom => and(
+            binary_expr(*expr.clone(), Operator::GtEq, lower),
+            binary_expr(*expr, Operator::Lt, upper),
+        ),
+        // <expr> != x ==> (<expr> < lower) or (<expr> >= upper)
+        Operator::NotEq => or(
+            binary_expr(*expr.clone(), Operator::Lt, lower),
+            binary_expr(*expr, Operator::GtEq, upper),
+        ),
+        // <expr> is distinct from x ==> (<expr> < lower) or (<expr> >= upper) 
or (<expr> is NULL and x is not NULL) or (<expr> is not NULL and x is NULL)
+        // but given that x is always not NULL => (<expr> < lower) or (<expr> 
>= upper) or (<expr> is NULL)
+        Operator::IsDistinctFrom => binary_expr(*expr.clone(), Operator::Lt, 
lower)
+            .or(binary_expr(*expr.clone(), Operator::GtEq, upper))
+            .or(expr.is_null()),
+        _ => return internal_err!("Expect comparison operators"),
+    };
+    Ok(Transformed::yes(rewritten_expr))
+}
+
+#[cfg(test)]
+mod test {
+    use std::any::Any;
+    use std::sync::Arc;
+
+    use arrow::datatypes::{DataType, Field};
+    use datafusion_common::{DFSchema, DFSchemaRef, Result, ScalarValue};
+    use datafusion_expr::{
+        ColumnarValue, Expr, Operator, ScalarFunctionArgs, ScalarUDF, 
ScalarUDFImpl,
+        Signature, Volatility, and, binary_expr, col, expr::ScalarFunction, 
lit,
+        simplify::SimplifyContext,
+    };
+
+    use super::Interval;
+    use crate::simplify_expressions::ExprSimplifier;
+
+    #[derive(Debug, PartialEq, Eq, Hash)]
+    struct PreimageUdf {
+        signature: Signature,
+    }
+
+    impl ScalarUDFImpl for PreimageUdf {
+        fn as_any(&self) -> &dyn Any {
+            self
+        }
+
+        fn name(&self) -> &str {
+            "preimage_func"
+        }
+
+        fn signature(&self) -> &Signature {
+            &self.signature
+        }
+
+        fn return_type(&self, _arg_types: &[DataType]) -> Result<DataType> {
+            Ok(DataType::Int32)
+        }
+
+        fn invoke_with_args(&self, _args: ScalarFunctionArgs) -> 
Result<ColumnarValue> {
+            Ok(ColumnarValue::Scalar(ScalarValue::Int32(Some(500))))
+        }
+
+        fn preimage(
+            &self,
+            args: &[Expr],
+            lit_expr: &Expr,
+            _info: &SimplifyContext,
+        ) -> Result<Option<Interval>> {
+            if args.len() != 1 {
+                return Ok(None);
+            }
+            match lit_expr {
+                Expr::Literal(ScalarValue::Int32(Some(500)), _) => {
+                    Ok(Some(Interval::try_new(
+                        ScalarValue::Int32(Some(100)),
+                        ScalarValue::Int32(Some(200)),
+                    )?))
+                }
+                _ => Ok(None),
+            }
+        }
+
+        fn column_expr(&self, args: &[Expr]) -> Option<Expr> {
+            args.first().cloned()
+        }
+    }
+
+    fn optimize_test(expr: Expr, schema: &DFSchemaRef) -> Expr {
+        let simplifier = ExprSimplifier::new(
+            SimplifyContext::default().with_schema(Arc::clone(schema)),
+        );
+
+        simplifier.simplify(expr).unwrap()
+    }
+
+    fn preimage_udf_expr() -> Expr {
+        let udf = ScalarUDF::new_from_impl(PreimageUdf {
+            signature: Signature::exact(vec![DataType::Int32], 
Volatility::Immutable),
+        });
+
+        Expr::ScalarFunction(ScalarFunction::new_udf(Arc::new(udf), 
vec![col("x")]))
+    }
+
+    fn test_schema() -> DFSchemaRef {
+        Arc::new(
+            DFSchema::from_unqualified_fields(
+                vec![Field::new("x", DataType::Int32, false)].into(),
+                Default::default(),
+            )
+            .unwrap(),
+        )
+    }
+
+    #[test]
+    fn test_preimage_eq_rewrite() {
+        let schema = test_schema();
+        let expr = binary_expr(preimage_udf_expr(), Operator::Eq, lit(500));
+        let expected = and(
+            binary_expr(col("x"), Operator::GtEq, lit(100)),
+            binary_expr(col("x"), Operator::Lt, lit(200)),
+        );

Review Comment:
   Could you please write these using the more concise strategy? I think it 
would help to review the cases
   
   
   ```suggestion
           let expected = and(
               col("x").gt_eq(lit(100)),
               col("x").lt(lit(200)),
           );
   ```



##########
datafusion/optimizer/src/simplify_expressions/udf_preimage.rs:
##########
@@ -0,0 +1,270 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements.  See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership.  The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License.  You may obtain a copy of the License at
+//
+//   http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied.  See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+use datafusion_common::{Result, internal_err, tree_node::Transformed};
+use datafusion_expr::{
+    Expr, Operator, and, binary_expr, lit, or, simplify::SimplifyContext,
+};
+use datafusion_expr_common::interval_arithmetic::Interval;
+
+/// Rewrites a binary expression using its "preimage"
+///
+/// Specifically it rewrites expressions of the form `<expr> OP x` (e.g. 
`<expr> =
+/// x`) where `<expr>` is known to have a pre-image (aka the entire single
+/// range for which it is valid)
+///
+/// This rewrite is described in the [ClickHouse Paper] and is particularly
+/// useful for simplifying expressions `date_part` or equivalent functions. The
+/// idea is that if you have an expression like `date_part(YEAR, k) = 2024` 
and you
+/// can find a [preimage] for `date_part(YEAR, k)`, which is the range of dates
+/// covering the entire year of 2024. Thus, you can rewrite the expression to
+/// `k >= '2024-01-01' AND k < '2025-01-01'`, which uses an inclusive lower 
bound
+/// and exclusive upper bound and is often more optimizable.
+///
+/// [ClickHouse Paper]:  https://www.vldb.org/pvldb/vol17/p3731-schulze.pdf
+/// [preimage]: https://en.wikipedia.org/wiki/Image_(mathematics)#Inverse_image
+///
+pub(super) fn rewrite_with_preimage(
+    _info: &SimplifyContext,
+    preimage_interval: Interval,
+    op: Operator,
+    expr: Box<Expr>,
+) -> Result<Transformed<Expr>> {
+    let (lower, upper) = preimage_interval.into_bounds();
+    let (lower, upper) = (lit(lower), lit(upper));
+
+    let rewritten_expr = match op {
+        // <expr> < x   ==>  <expr> < lower
+        // <expr> >= x  ==>  <expr> >= lower
+        Operator::Lt | Operator::GtEq => binary_expr(*expr, op, lower),
+        // <expr> > x ==> <expr> >= upper
+        Operator::Gt => binary_expr(*expr, Operator::GtEq, upper),
+        // <expr> <= x ==> <expr> < upper
+        Operator::LtEq => binary_expr(*expr, Operator::Lt, upper),
+        // <expr> = x ==> (<expr> >= lower) and (<expr> < upper)
+        //
+        // <expr> is not distinct from x ==> (<expr> is NULL and x is NULL) or 
((<expr> >= lower) and (<expr> < upper))
+        // but since x is always not NULL => (<expr> >= lower) and (<expr> < 
upper)
+        Operator::Eq | Operator::IsNotDistinctFrom => and(
+            binary_expr(*expr.clone(), Operator::GtEq, lower),
+            binary_expr(*expr, Operator::Lt, upper),
+        ),
+        // <expr> != x ==> (<expr> < lower) or (<expr> >= upper)
+        Operator::NotEq => or(
+            binary_expr(*expr.clone(), Operator::Lt, lower),
+            binary_expr(*expr, Operator::GtEq, upper),
+        ),
+        // <expr> is distinct from x ==> (<expr> < lower) or (<expr> >= upper) 
or (<expr> is NULL and x is not NULL) or (<expr> is not NULL and x is NULL)
+        // but given that x is always not NULL => (<expr> < lower) or (<expr> 
>= upper) or (<expr> is NULL)
+        Operator::IsDistinctFrom => binary_expr(*expr.clone(), Operator::Lt, 
lower)
+            .or(binary_expr(*expr.clone(), Operator::GtEq, upper))
+            .or(expr.is_null()),
+        _ => return internal_err!("Expect comparison operators"),
+    };
+    Ok(Transformed::yes(rewritten_expr))
+}
+
+#[cfg(test)]
+mod test {
+    use std::any::Any;
+    use std::sync::Arc;
+
+    use arrow::datatypes::{DataType, Field};
+    use datafusion_common::{DFSchema, DFSchemaRef, Result, ScalarValue};
+    use datafusion_expr::{
+        ColumnarValue, Expr, Operator, ScalarFunctionArgs, ScalarUDF, 
ScalarUDFImpl,
+        Signature, Volatility, and, binary_expr, col, expr::ScalarFunction, 
lit,
+        simplify::SimplifyContext,
+    };
+
+    use super::Interval;
+    use crate::simplify_expressions::ExprSimplifier;
+
+    #[derive(Debug, PartialEq, Eq, Hash)]
+    struct PreimageUdf {
+        signature: Signature,
+    }
+
+    impl ScalarUDFImpl for PreimageUdf {
+        fn as_any(&self) -> &dyn Any {
+            self
+        }
+
+        fn name(&self) -> &str {
+            "preimage_func"
+        }
+
+        fn signature(&self) -> &Signature {
+            &self.signature
+        }
+
+        fn return_type(&self, _arg_types: &[DataType]) -> Result<DataType> {
+            Ok(DataType::Int32)
+        }
+
+        fn invoke_with_args(&self, _args: ScalarFunctionArgs) -> 
Result<ColumnarValue> {
+            Ok(ColumnarValue::Scalar(ScalarValue::Int32(Some(500))))
+        }
+
+        fn preimage(
+            &self,
+            args: &[Expr],
+            lit_expr: &Expr,
+            _info: &SimplifyContext,
+        ) -> Result<Option<Interval>> {
+            if args.len() != 1 {
+                return Ok(None);
+            }
+            match lit_expr {
+                Expr::Literal(ScalarValue::Int32(Some(500)), _) => {

Review Comment:
   Given this has to check for Expr::Literal anyways, I think the simplify 
expression could just pass whatever argument in here, rather than only doing it 
with columns and literals



##########
datafusion/expr/src/udf.rs:
##########
@@ -696,6 +715,40 @@ pub trait ScalarUDFImpl: Debug + DynEq + DynHash + Send + 
Sync {
         Ok(ExprSimplifyResult::Original(args))
     }
 
+    /// Returns the [preimage] for this function and the specified scalar 
value, if any.
+    ///
+    /// A preimage is a single contiguous [`Interval`] of values where the 
function
+    /// will always return `lit_value`
+    ///
+    /// Implementations should return intervals with an inclusive lower bound 
and
+    /// exclusive upper bound.
+    ///
+    /// This rewrite is described in the [ClickHouse Paper] and is particularly
+    /// useful for simplifying expressions `date_part` or equivalent 
functions. The
+    /// idea is that if you have an expression like `date_part(YEAR, k) = 
2024` and you
+    /// can find a [preimage] for `date_part(YEAR, k)`, which is the range of 
dates
+    /// covering the entire year of 2024. Thus, you can rewrite the expression 
to `k
+    /// >= '2024-01-01' AND k < '2025-01-01' which is often more optimizable.
+    ///
+    /// Implementations must also provide [`ScalarUDFImpl::column_expr`] so the
+    /// optimizer can identify which argument maps to the preimage interval.
+    ///
+    /// [ClickHouse Paper]:  https://www.vldb.org/pvldb/vol17/p3731-schulze.pdf
+    /// [preimage]: 
https://en.wikipedia.org/wiki/Image_(mathematics)#Inverse_image
+    fn preimage(
+        &self,
+        _args: &[Expr],
+        _lit_expr: &Expr,
+        _info: &SimplifyContext,
+    ) -> Result<Option<Interval>> {
+        Ok(None)
+    }
+
+    // Return the inner column expression from this function

Review Comment:
   I see you need this to get the argument, but it seems very error prone, as 
it 
   1. requires that `preimage` and this function return consistent results
   2. it is not clear how this would work when the inner expression was not a 
single column expression (e.g. `date_part(YEAR, '2025-01-02'::date)`)
   
   
   What if we instead returned the inner expression as part of the call to 
`preimage` itself? That would be consistent with what `get_preimage` does below 
anyways
   
   So that would mean something like
   
   
   ```rust
       /// A preimage is a single contiguous [`Interval`] of values where the 
function
       /// will always return `lit_value`. If a pre-image exists, returns the 
interval
       /// and the expression for which this holds
       /// 
       /// For example, `date_trunc(YEAR, ts_col)` when called with 
`preimage(2025)` would return
       /// `(ts_col, [2025-01-01, 2026-01-01))`
       fn preimage(
           &self,
           _args: &[Expr],
           _lit_expr: &Expr,
           _info: &SimplifyContext,
         ) -> Result<Option<(Expr, Interval)>> {
             Ok(None)
        }
   ```
   
   Bonus points for wrapping it in a enum, consistent with how 
`ExprSimplifyResult` works: 
https://github.com/apache/datafusion/blob/35ff4ab0a03fcc6615876eac76bac19887059ab3/datafusion/expr/src/simplify.rs#L115-L121
   
   ```rust
   pub enum PreimageResult {
     /// No preimage exists for the specified value
     None, 
     /// The expression always evaluates to the specified constant
     /// given that `expr` is within the interval
     Range { 
       expr: Expr,
       interval: Interval
     }
   }
   
   
   ...
       fn preimage(
           &self,
           _args: &[Expr],
           _lit_expr: &Expr,
           _info: &SimplifyContext,
         ) -> Result<PreimageResult> {
             Ok(PreimageResult::None)
        }
   ```
   



-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

To unsubscribe, e-mail: [email protected]

For queries about this service, please contact Infrastructure at:
[email protected]


---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]

Reply via email to