This is an automated email from the ASF dual-hosted git repository.

alamb pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/datafusion.git


The following commit(s) were added to refs/heads/main by this push:
     new 1b10c9f89e Make PruningPredicate's rewrite public (#12850)
1b10c9f89e is described below

commit 1b10c9f89eac127507fe7a137ff7c40534f7ca9a
Author: Adrian Garcia Badaracco <[email protected]>
AuthorDate: Sun Oct 13 07:26:42 2024 -0500

    Make PruningPredicate's rewrite public (#12850)
    
    * Make PruningPredicate's rewrite public
    
    * feedback
    
    * Improve documentation and add default to ConstantUnhandledPredicateHook
    
    * Update pruning.rs
    
    ---------
    
    Co-authored-by: Andrew Lamb <[email protected]>
---
 datafusion/core/src/physical_optimizer/pruning.rs | 212 +++++++++++++++++++---
 1 file changed, 188 insertions(+), 24 deletions(-)

diff --git a/datafusion/core/src/physical_optimizer/pruning.rs 
b/datafusion/core/src/physical_optimizer/pruning.rs
index 9bc2bb1d1d..eb03b33777 100644
--- a/datafusion/core/src/physical_optimizer/pruning.rs
+++ b/datafusion/core/src/physical_optimizer/pruning.rs
@@ -458,7 +458,7 @@ pub trait PruningStatistics {
 /// [`Snowflake SIGMOD Paper`]: https://dl.acm.org/doi/10.1145/2882903.2903741
 /// [small materialized aggregates]: https://www.vldb.org/conf/1998/p476.pdf
 /// [zone maps]: https://dl.acm.org/doi/10.1007/978-3-642-03730-6_10
-///[data skipping]: https://dl.acm.org/doi/10.1145/2588555.2610515
+/// [data skipping]: https://dl.acm.org/doi/10.1145/2588555.2610515
 #[derive(Debug, Clone)]
 pub struct PruningPredicate {
     /// The input schema against which the predicate will be evaluated
@@ -478,6 +478,36 @@ pub struct PruningPredicate {
     literal_guarantees: Vec<LiteralGuarantee>,
 }
 
+/// Rewrites predicates that [`PredicateRewriter`] can not handle, e.g. certain
+/// complex expressions or predicates that reference columns that are not in 
the
+/// schema.
+pub trait UnhandledPredicateHook {
+    /// Called when a predicate can not be rewritten in terms of statistics or
+    /// references a column that is not in the schema.
+    fn handle(&self, expr: &Arc<dyn PhysicalExpr>) -> Arc<dyn PhysicalExpr>;
+}
+
+/// The default handling for unhandled predicates is to return a constant 
`true`
+/// (meaning don't prune the container)
+#[derive(Debug, Clone)]
+struct ConstantUnhandledPredicateHook {
+    default: Arc<dyn PhysicalExpr>,
+}
+
+impl Default for ConstantUnhandledPredicateHook {
+    fn default() -> Self {
+        Self {
+            default: 
Arc::new(phys_expr::Literal::new(ScalarValue::from(true))),
+        }
+    }
+}
+
+impl UnhandledPredicateHook for ConstantUnhandledPredicateHook {
+    fn handle(&self, _expr: &Arc<dyn PhysicalExpr>) -> Arc<dyn PhysicalExpr> {
+        self.default.clone()
+    }
+}
+
 impl PruningPredicate {
     /// Try to create a new instance of [`PruningPredicate`]
     ///
@@ -502,10 +532,16 @@ impl PruningPredicate {
     /// See the struct level documentation on [`PruningPredicate`] for more
     /// details.
     pub fn try_new(expr: Arc<dyn PhysicalExpr>, schema: SchemaRef) -> 
Result<Self> {
+        let unhandled_hook = 
Arc::new(ConstantUnhandledPredicateHook::default()) as _;
+
         // build predicate expression once
         let mut required_columns = RequiredColumns::new();
-        let predicate_expr =
-            build_predicate_expression(&expr, schema.as_ref(), &mut 
required_columns);
+        let predicate_expr = build_predicate_expression(
+            &expr,
+            schema.as_ref(),
+            &mut required_columns,
+            &unhandled_hook,
+        );
 
         let literal_guarantees = LiteralGuarantee::analyze(&expr);
 
@@ -1312,27 +1348,78 @@ fn build_is_null_column_expr(
 /// an OR chain
 const MAX_LIST_VALUE_SIZE_REWRITE: usize = 20;
 
+/// Rewrite a predicate expression in terms of statistics (min/max/null_counts)
+/// for use as a [`PruningPredicate`].
+pub struct PredicateRewriter {
+    unhandled_hook: Arc<dyn UnhandledPredicateHook>,
+}
+
+impl Default for PredicateRewriter {
+    fn default() -> Self {
+        Self {
+            unhandled_hook: 
Arc::new(ConstantUnhandledPredicateHook::default()),
+        }
+    }
+}
+
+impl PredicateRewriter {
+    /// Create a new `PredicateRewriter`
+    pub fn new() -> Self {
+        Self::default()
+    }
+
+    /// Set the unhandled hook to be used when a predicate can not be rewritten
+    pub fn with_unhandled_hook(
+        self,
+        unhandled_hook: Arc<dyn UnhandledPredicateHook>,
+    ) -> Self {
+        Self { unhandled_hook }
+    }
+
+    /// Translate logical filter expression into pruning predicate
+    /// expression that will evaluate to FALSE if it can be determined no
+    /// rows between the min/max values could pass the predicates.
+    ///
+    /// Any predicates that can not be translated will be passed to 
`unhandled_hook`.
+    ///
+    /// Returns the pruning predicate as a [`PhysicalExpr`]
+    ///
+    /// Notice: Does not handle [`phys_expr::InListExpr`] greater than 20, 
which will fall back to calling `unhandled_hook`
+    pub fn rewrite_predicate_to_statistics_predicate(
+        &self,
+        expr: &Arc<dyn PhysicalExpr>,
+        schema: &Schema,
+    ) -> Arc<dyn PhysicalExpr> {
+        let mut required_columns = RequiredColumns::new();
+        build_predicate_expression(
+            expr,
+            schema,
+            &mut required_columns,
+            &self.unhandled_hook,
+        )
+    }
+}
+
 /// Translate logical filter expression into pruning predicate
 /// expression that will evaluate to FALSE if it can be determined no
 /// rows between the min/max values could pass the predicates.
 ///
+/// Any predicates that can not be translated will be passed to 
`unhandled_hook`.
+///
 /// Returns the pruning predicate as an [`PhysicalExpr`]
 ///
-/// Notice: Does not handle [`phys_expr::InListExpr`] greater than 20, which 
will be rewritten to TRUE
+/// Notice: Does not handle [`phys_expr::InListExpr`] greater than 20, which 
will fall back to calling `unhandled_hook`
 fn build_predicate_expression(
     expr: &Arc<dyn PhysicalExpr>,
     schema: &Schema,
     required_columns: &mut RequiredColumns,
+    unhandled_hook: &Arc<dyn UnhandledPredicateHook>,
 ) -> Arc<dyn PhysicalExpr> {
-    // Returned for unsupported expressions. Such expressions are
-    // converted to TRUE.
-    let unhandled = 
Arc::new(phys_expr::Literal::new(ScalarValue::Boolean(Some(true))));
-
     // predicate expression can only be a binary expression
     let expr_any = expr.as_any();
     if let Some(is_null) = expr_any.downcast_ref::<phys_expr::IsNullExpr>() {
         return build_is_null_column_expr(is_null.arg(), schema, 
required_columns, false)
-            .unwrap_or(unhandled);
+            .unwrap_or_else(|| unhandled_hook.handle(expr));
     }
     if let Some(is_not_null) = 
expr_any.downcast_ref::<phys_expr::IsNotNullExpr>() {
         return build_is_null_column_expr(
@@ -1341,19 +1428,19 @@ fn build_predicate_expression(
             required_columns,
             true,
         )
-        .unwrap_or(unhandled);
+        .unwrap_or_else(|| unhandled_hook.handle(expr));
     }
     if let Some(col) = expr_any.downcast_ref::<phys_expr::Column>() {
         return build_single_column_expr(col, schema, required_columns, false)
-            .unwrap_or(unhandled);
+            .unwrap_or_else(|| unhandled_hook.handle(expr));
     }
     if let Some(not) = expr_any.downcast_ref::<phys_expr::NotExpr>() {
         // match !col (don't do so recursively)
         if let Some(col) = 
not.arg().as_any().downcast_ref::<phys_expr::Column>() {
             return build_single_column_expr(col, schema, required_columns, 
true)
-                .unwrap_or(unhandled);
+                .unwrap_or_else(|| unhandled_hook.handle(expr));
         } else {
-            return unhandled;
+            return unhandled_hook.handle(expr);
         }
     }
     if let Some(in_list) = expr_any.downcast_ref::<phys_expr::InListExpr>() {
@@ -1382,9 +1469,14 @@ fn build_predicate_expression(
                 })
                 .reduce(|a, b| Arc::new(phys_expr::BinaryExpr::new(a, re_op, 
b)) as _)
                 .unwrap();
-            return build_predicate_expression(&change_expr, schema, 
required_columns);
+            return build_predicate_expression(
+                &change_expr,
+                schema,
+                required_columns,
+                unhandled_hook,
+            );
         } else {
-            return unhandled;
+            return unhandled_hook.handle(expr);
         }
     }
 
@@ -1396,13 +1488,15 @@ fn build_predicate_expression(
                 bin_expr.right().clone(),
             )
         } else {
-            return unhandled;
+            return unhandled_hook.handle(expr);
         }
     };
 
     if op == Operator::And || op == Operator::Or {
-        let left_expr = build_predicate_expression(&left, schema, 
required_columns);
-        let right_expr = build_predicate_expression(&right, schema, 
required_columns);
+        let left_expr =
+            build_predicate_expression(&left, schema, required_columns, 
unhandled_hook);
+        let right_expr =
+            build_predicate_expression(&right, schema, required_columns, 
unhandled_hook);
         // simplify boolean expression if applicable
         let expr = match (&left_expr, op, &right_expr) {
             (left, Operator::And, _) if is_always_true(left) => right_expr,
@@ -1410,7 +1504,7 @@ fn build_predicate_expression(
             (left, Operator::Or, right)
                 if is_always_true(left) || is_always_true(right) =>
             {
-                unhandled
+                
Arc::new(phys_expr::Literal::new(ScalarValue::Boolean(Some(true))))
             }
             _ => Arc::new(phys_expr::BinaryExpr::new(left_expr, op, 
right_expr)),
         };
@@ -1423,12 +1517,11 @@ fn build_predicate_expression(
         Ok(builder) => builder,
         // allow partial failure in predicate expression generation
         // this can still produce a useful predicate when multiple conditions 
are joined using AND
-        Err(_) => {
-            return unhandled;
-        }
+        Err(_) => return unhandled_hook.handle(expr),
     };
 
-    build_statistics_expr(&mut expr_builder).unwrap_or(unhandled)
+    build_statistics_expr(&mut expr_builder)
+        .unwrap_or_else(|_| unhandled_hook.handle(expr))
 }
 
 fn build_statistics_expr(
@@ -1582,6 +1675,8 @@ mod tests {
     use arrow_array::UInt64Array;
     use datafusion_expr::expr::InList;
     use datafusion_expr::{cast, is_null, try_cast, Expr};
+    use datafusion_functions_nested::expr_fn::{array_has, make_array};
+    use datafusion_physical_expr::expressions as phys_expr;
     use datafusion_physical_expr::planner::logical2physical;
 
     #[derive(Debug, Default)]
@@ -3397,6 +3492,74 @@ mod tests {
         // TODO: add test for other case and op
     }
 
+    #[test]
+    fn test_rewrite_expr_to_prunable_custom_unhandled_hook() {
+        struct CustomUnhandledHook;
+
+        impl UnhandledPredicateHook for CustomUnhandledHook {
+            /// Replaces every unhandled predicate (e.g. one that references a
+            /// column missing from the schema) with the literal `42`; the value
+            /// is arbitrary, the point is that the hook can do whatever it wants
+            fn handle(&self, _expr: &Arc<dyn PhysicalExpr>) -> Arc<dyn 
PhysicalExpr> {
+                Arc::new(phys_expr::Literal::new(ScalarValue::Int32(Some(42))))
+            }
+        }
+
+        let schema = Schema::new(vec![Field::new("a", DataType::Int32, true)]);
+        let schema_with_b = Schema::new(vec![
+            Field::new("a", DataType::Int32, true),
+            Field::new("b", DataType::Int32, true),
+        ]);
+
+        let rewriter = PredicateRewriter::new()
+            .with_unhandled_hook(Arc::new(CustomUnhandledHook {}));
+
+        let transform_expr = |expr| {
+            let expr = logical2physical(&expr, &schema_with_b);
+            rewriter.rewrite_predicate_to_statistics_predicate(&expr, &schema)
+        };
+
+        // transform an arbitrary valid expression that we know is handled
+        let known_expression = col("a").eq(lit(12));
+        let known_expression_transformed = PredicateRewriter::new()
+            .rewrite_predicate_to_statistics_predicate(
+                &logical2physical(&known_expression, &schema),
+                &schema,
+            );
+
+        // an expression referencing an unknown column (that is not in the 
schema) gets passed to the hook
+        let input = col("b").eq(lit(12));
+        let expected = logical2physical(&lit(42), &schema);
+        let transformed = transform_expr(input.clone());
+        assert_eq!(transformed.to_string(), expected.to_string());
+
+        // more complex case with unknown column
+        let input = known_expression.clone().and(input.clone());
+        let expected = phys_expr::BinaryExpr::new(
+            known_expression_transformed.clone(),
+            Operator::And,
+            logical2physical(&lit(42), &schema),
+        );
+        let transformed = transform_expr(input.clone());
+        assert_eq!(transformed.to_string(), expected.to_string());
+
+        // an unknown expression gets passed to the hook
+        let input = array_has(make_array(vec![lit(1)]), col("a"));
+        let expected = logical2physical(&lit(42), &schema);
+        let transformed = transform_expr(input.clone());
+        assert_eq!(transformed.to_string(), expected.to_string());
+
+        // more complex case with unknown expression
+        let input = known_expression.and(input);
+        let expected = phys_expr::BinaryExpr::new(
+            known_expression_transformed.clone(),
+            Operator::And,
+            logical2physical(&lit(42), &schema),
+        );
+        let transformed = transform_expr(input.clone());
+        assert_eq!(transformed.to_string(), expected.to_string());
+    }
+
     #[test]
     fn test_rewrite_expr_to_prunable_error() {
         // cast string value to numeric value
@@ -3886,6 +4049,7 @@ mod tests {
         required_columns: &mut RequiredColumns,
     ) -> Arc<dyn PhysicalExpr> {
         let expr = logical2physical(expr, schema);
-        build_predicate_expression(&expr, schema, required_columns)
+        let unhandled_hook = 
Arc::new(ConstantUnhandledPredicateHook::default()) as _;
+        build_predicate_expression(&expr, schema, required_columns, 
&unhandled_hook)
     }
 }


---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]

Reply via email to