This is an automated email from the ASF dual-hosted git repository.
alamb pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/datafusion.git
The following commit(s) were added to refs/heads/main by this push:
new 1b10c9f89e Make PruningPredicate's rewrite public (#12850)
1b10c9f89e is described below
commit 1b10c9f89eac127507fe7a137ff7c40534f7ca9a
Author: Adrian Garcia Badaracco <[email protected]>
AuthorDate: Sun Oct 13 07:26:42 2024 -0500
Make PruningPredicate's rewrite public (#12850)
* Make PruningPredicate's rewrite public
* feedback
* Improve documentation and add default to ConstantUnhandledPredicateHook
* Update pruning.rs
---------
Co-authored-by: Andrew Lamb <[email protected]>
---
datafusion/core/src/physical_optimizer/pruning.rs | 212 +++++++++++++++++++---
1 file changed, 188 insertions(+), 24 deletions(-)
diff --git a/datafusion/core/src/physical_optimizer/pruning.rs b/datafusion/core/src/physical_optimizer/pruning.rs
index 9bc2bb1d1d..eb03b33777 100644
--- a/datafusion/core/src/physical_optimizer/pruning.rs
+++ b/datafusion/core/src/physical_optimizer/pruning.rs
@@ -458,7 +458,7 @@ pub trait PruningStatistics {
/// [`Snowflake SIGMOD Paper`]: https://dl.acm.org/doi/10.1145/2882903.2903741
/// [small materialized aggregates]: https://www.vldb.org/conf/1998/p476.pdf
/// [zone maps]: https://dl.acm.org/doi/10.1007/978-3-642-03730-6_10
-///[data skipping]: https://dl.acm.org/doi/10.1145/2588555.2610515
+/// [data skipping]: https://dl.acm.org/doi/10.1145/2588555.2610515
#[derive(Debug, Clone)]
pub struct PruningPredicate {
/// The input schema against which the predicate will be evaluated
@@ -478,6 +478,36 @@ pub struct PruningPredicate {
literal_guarantees: Vec<LiteralGuarantee>,
}
+/// Rewrites predicates that [`PredicateRewriter`] can not handle, e.g. certain
+/// complex expressions or predicates that reference columns that are not in the
+/// schema.
+pub trait UnhandledPredicateHook {
+ /// Called when a predicate can not be rewritten in terms of statistics or
+ /// references a column that is not in the schema.
+ fn handle(&self, expr: &Arc<dyn PhysicalExpr>) -> Arc<dyn PhysicalExpr>;
+}
+
+/// The default handling for unhandled predicates is to return a constant `true`
+/// (meaning don't prune the container)
+#[derive(Debug, Clone)]
+struct ConstantUnhandledPredicateHook {
+ default: Arc<dyn PhysicalExpr>,
+}
+
+impl Default for ConstantUnhandledPredicateHook {
+ fn default() -> Self {
+ Self {
+ default:
Arc::new(phys_expr::Literal::new(ScalarValue::from(true))),
+ }
+ }
+}
+
+impl UnhandledPredicateHook for ConstantUnhandledPredicateHook {
+ fn handle(&self, _expr: &Arc<dyn PhysicalExpr>) -> Arc<dyn PhysicalExpr> {
+ self.default.clone()
+ }
+}
+
impl PruningPredicate {
/// Try to create a new instance of [`PruningPredicate`]
///
@@ -502,10 +532,16 @@ impl PruningPredicate {
/// See the struct level documentation on [`PruningPredicate`] for more
/// details.
pub fn try_new(expr: Arc<dyn PhysicalExpr>, schema: SchemaRef) ->
Result<Self> {
+ let unhandled_hook =
Arc::new(ConstantUnhandledPredicateHook::default()) as _;
+
// build predicate expression once
let mut required_columns = RequiredColumns::new();
- let predicate_expr =
- build_predicate_expression(&expr, schema.as_ref(), &mut
required_columns);
+ let predicate_expr = build_predicate_expression(
+ &expr,
+ schema.as_ref(),
+ &mut required_columns,
+ &unhandled_hook,
+ );
let literal_guarantees = LiteralGuarantee::analyze(&expr);
@@ -1312,27 +1348,78 @@ fn build_is_null_column_expr(
/// an OR chain
const MAX_LIST_VALUE_SIZE_REWRITE: usize = 20;
+/// Rewrite a predicate expression in terms of statistics (min/max/null_counts)
+/// for use as a [`PruningPredicate`].
+pub struct PredicateRewriter {
+ unhandled_hook: Arc<dyn UnhandledPredicateHook>,
+}
+
+impl Default for PredicateRewriter {
+ fn default() -> Self {
+ Self {
+ unhandled_hook:
Arc::new(ConstantUnhandledPredicateHook::default()),
+ }
+ }
+}
+
+impl PredicateRewriter {
+ /// Create a new `PredicateRewriter`
+ pub fn new() -> Self {
+ Self::default()
+ }
+
+ /// Set the unhandled hook to be used when a predicate can not be rewritten
+ pub fn with_unhandled_hook(
+ self,
+ unhandled_hook: Arc<dyn UnhandledPredicateHook>,
+ ) -> Self {
+ Self { unhandled_hook }
+ }
+
+ /// Translate logical filter expression into pruning predicate
+ /// expression that will evaluate to FALSE if it can be determined no
+ /// rows between the min/max values could pass the predicates.
+ ///
+ /// Any predicates that can not be translated will be passed to `unhandled_hook`.
+ ///
+ /// Returns the pruning predicate as an [`PhysicalExpr`]
+ ///
+ /// Notice: Does not handle [`phys_expr::InListExpr`] greater than 20, which will fall back to calling `unhandled_hook`
+ pub fn rewrite_predicate_to_statistics_predicate(
+ &self,
+ expr: &Arc<dyn PhysicalExpr>,
+ schema: &Schema,
+ ) -> Arc<dyn PhysicalExpr> {
+ let mut required_columns = RequiredColumns::new();
+ build_predicate_expression(
+ expr,
+ schema,
+ &mut required_columns,
+ &self.unhandled_hook,
+ )
+ }
+}
+
/// Translate logical filter expression into pruning predicate
/// expression that will evaluate to FALSE if it can be determined no
/// rows between the min/max values could pass the predicates.
///
+/// Any predicates that can not be translated will be passed to `unhandled_hook`.
+///
/// Returns the pruning predicate as an [`PhysicalExpr`]
///
-/// Notice: Does not handle [`phys_expr::InListExpr`] greater than 20, which will be rewritten to TRUE
+/// Notice: Does not handle [`phys_expr::InListExpr`] greater than 20, which will fall back to calling `unhandled_hook`
fn build_predicate_expression(
expr: &Arc<dyn PhysicalExpr>,
schema: &Schema,
required_columns: &mut RequiredColumns,
+ unhandled_hook: &Arc<dyn UnhandledPredicateHook>,
) -> Arc<dyn PhysicalExpr> {
- // Returned for unsupported expressions. Such expressions are
- // converted to TRUE.
- let unhandled =
Arc::new(phys_expr::Literal::new(ScalarValue::Boolean(Some(true))));
-
// predicate expression can only be a binary expression
let expr_any = expr.as_any();
if let Some(is_null) = expr_any.downcast_ref::<phys_expr::IsNullExpr>() {
return build_is_null_column_expr(is_null.arg(), schema,
required_columns, false)
- .unwrap_or(unhandled);
+ .unwrap_or_else(|| unhandled_hook.handle(expr));
}
if let Some(is_not_null) =
expr_any.downcast_ref::<phys_expr::IsNotNullExpr>() {
return build_is_null_column_expr(
@@ -1341,19 +1428,19 @@ fn build_predicate_expression(
required_columns,
true,
)
- .unwrap_or(unhandled);
+ .unwrap_or_else(|| unhandled_hook.handle(expr));
}
if let Some(col) = expr_any.downcast_ref::<phys_expr::Column>() {
return build_single_column_expr(col, schema, required_columns, false)
- .unwrap_or(unhandled);
+ .unwrap_or_else(|| unhandled_hook.handle(expr));
}
if let Some(not) = expr_any.downcast_ref::<phys_expr::NotExpr>() {
// match !col (don't do so recursively)
if let Some(col) =
not.arg().as_any().downcast_ref::<phys_expr::Column>() {
return build_single_column_expr(col, schema, required_columns,
true)
- .unwrap_or(unhandled);
+ .unwrap_or_else(|| unhandled_hook.handle(expr));
} else {
- return unhandled;
+ return unhandled_hook.handle(expr);
}
}
if let Some(in_list) = expr_any.downcast_ref::<phys_expr::InListExpr>() {
@@ -1382,9 +1469,14 @@ fn build_predicate_expression(
})
.reduce(|a, b| Arc::new(phys_expr::BinaryExpr::new(a, re_op,
b)) as _)
.unwrap();
- return build_predicate_expression(&change_expr, schema,
required_columns);
+ return build_predicate_expression(
+ &change_expr,
+ schema,
+ required_columns,
+ unhandled_hook,
+ );
} else {
- return unhandled;
+ return unhandled_hook.handle(expr);
}
}
@@ -1396,13 +1488,15 @@ fn build_predicate_expression(
bin_expr.right().clone(),
)
} else {
- return unhandled;
+ return unhandled_hook.handle(expr);
}
};
if op == Operator::And || op == Operator::Or {
- let left_expr = build_predicate_expression(&left, schema,
required_columns);
- let right_expr = build_predicate_expression(&right, schema,
required_columns);
+ let left_expr =
+ build_predicate_expression(&left, schema, required_columns,
unhandled_hook);
+ let right_expr =
+ build_predicate_expression(&right, schema, required_columns,
unhandled_hook);
// simplify boolean expression if applicable
let expr = match (&left_expr, op, &right_expr) {
(left, Operator::And, _) if is_always_true(left) => right_expr,
@@ -1410,7 +1504,7 @@ fn build_predicate_expression(
(left, Operator::Or, right)
if is_always_true(left) || is_always_true(right) =>
{
- unhandled
+
Arc::new(phys_expr::Literal::new(ScalarValue::Boolean(Some(true))))
}
_ => Arc::new(phys_expr::BinaryExpr::new(left_expr, op,
right_expr)),
};
@@ -1423,12 +1517,11 @@ fn build_predicate_expression(
Ok(builder) => builder,
// allow partial failure in predicate expression generation
// this can still produce a useful predicate when multiple conditions are joined using AND
- Err(_) => {
- return unhandled;
- }
+ Err(_) => return unhandled_hook.handle(expr),
};
- build_statistics_expr(&mut expr_builder).unwrap_or(unhandled)
+ build_statistics_expr(&mut expr_builder)
+ .unwrap_or_else(|_| unhandled_hook.handle(expr))
}
fn build_statistics_expr(
@@ -1582,6 +1675,8 @@ mod tests {
use arrow_array::UInt64Array;
use datafusion_expr::expr::InList;
use datafusion_expr::{cast, is_null, try_cast, Expr};
+ use datafusion_functions_nested::expr_fn::{array_has, make_array};
+ use datafusion_physical_expr::expressions as phys_expr;
use datafusion_physical_expr::planner::logical2physical;
#[derive(Debug, Default)]
@@ -3397,6 +3492,74 @@ mod tests {
// TODO: add test for other case and op
}
+ #[test]
+ fn test_rewrite_expr_to_prunable_custom_unhandled_hook() {
+ struct CustomUnhandledHook;
+
+ impl UnhandledPredicateHook for CustomUnhandledHook {
+ /// This handles any unhandled predicate (e.g. one referencing a column that doesn't exist in the schema)
+ /// by replacing it with the literal `42`
+ /// (the transformation is arbitrary, the point is that it can do whatever it wants)
+ fn handle(&self, _expr: &Arc<dyn PhysicalExpr>) -> Arc<dyn
PhysicalExpr> {
+ Arc::new(phys_expr::Literal::new(ScalarValue::Int32(Some(42))))
+ }
+ }
+
+ let schema = Schema::new(vec![Field::new("a", DataType::Int32, true)]);
+ let schema_with_b = Schema::new(vec![
+ Field::new("a", DataType::Int32, true),
+ Field::new("b", DataType::Int32, true),
+ ]);
+
+ let rewriter = PredicateRewriter::new()
+ .with_unhandled_hook(Arc::new(CustomUnhandledHook {}));
+
+ let transform_expr = |expr| {
+ let expr = logical2physical(&expr, &schema_with_b);
+ rewriter.rewrite_predicate_to_statistics_predicate(&expr, &schema)
+ };
+
+ // transform an arbitrary valid expression that we know is handled
+ let known_expression = col("a").eq(lit(12));
+ let known_expression_transformed = PredicateRewriter::new()
+ .rewrite_predicate_to_statistics_predicate(
+ &logical2physical(&known_expression, &schema),
+ &schema,
+ );
+
+ // an expression referencing an unknown column (that is not in the schema) gets passed to the hook
+ let input = col("b").eq(lit(12));
+ let expected = logical2physical(&lit(42), &schema);
+ let transformed = transform_expr(input.clone());
+ assert_eq!(transformed.to_string(), expected.to_string());
+
+ // more complex case with unknown column
+ let input = known_expression.clone().and(input.clone());
+ let expected = phys_expr::BinaryExpr::new(
+ known_expression_transformed.clone(),
+ Operator::And,
+ logical2physical(&lit(42), &schema),
+ );
+ let transformed = transform_expr(input.clone());
+ assert_eq!(transformed.to_string(), expected.to_string());
+
+ // an unknown expression gets passed to the hook
+ let input = array_has(make_array(vec![lit(1)]), col("a"));
+ let expected = logical2physical(&lit(42), &schema);
+ let transformed = transform_expr(input.clone());
+ assert_eq!(transformed.to_string(), expected.to_string());
+
+ // more complex case with unknown expression
+ let input = known_expression.and(input);
+ let expected = phys_expr::BinaryExpr::new(
+ known_expression_transformed.clone(),
+ Operator::And,
+ logical2physical(&lit(42), &schema),
+ );
+ let transformed = transform_expr(input.clone());
+ assert_eq!(transformed.to_string(), expected.to_string());
+ }
+
#[test]
fn test_rewrite_expr_to_prunable_error() {
// cast string value to numeric value
@@ -3886,6 +4049,7 @@ mod tests {
required_columns: &mut RequiredColumns,
) -> Arc<dyn PhysicalExpr> {
let expr = logical2physical(expr, schema);
- build_predicate_expression(&expr, schema, required_columns)
+ let unhandled_hook =
Arc::new(ConstantUnhandledPredicateHook::default()) as _;
+ build_predicate_expression(&expr, schema, required_columns,
&unhandled_hook)
}
}
---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]