This is an automated email from the ASF dual-hosted git repository.
thinkharderdev pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/arrow-datafusion.git
The following commit(s) were added to refs/heads/main by this push:
new c477fc0ca Fix filter pushdown for extension plans (#5425)
c477fc0ca is described below
commit c477fc0ca991cfb84d8d6f082879e969761d9125
Author: Dan Harris <[email protected]>
AuthorDate: Tue Feb 28 08:02:26 2023 -0500
Fix filter pushdown for extension plans (#5425)
* Fix filter pushdown for extension plans
* Update datafusion/optimizer/src/push_down_filter.rs
Co-authored-by: jakevin <[email protected]>
---------
Co-authored-by: jakevin <[email protected]>
---
datafusion/optimizer/src/push_down_filter.rs | 174 ++++++++++++++++++++++++++-
1 file changed, 172 insertions(+), 2 deletions(-)
diff --git a/datafusion/optimizer/src/push_down_filter.rs
b/datafusion/optimizer/src/push_down_filter.rs
index 6d8db2043..0d8da5573 100644
--- a/datafusion/optimizer/src/push_down_filter.rs
+++ b/datafusion/optimizer/src/push_down_filter.rs
@@ -735,6 +735,48 @@ impl OptimizerRule for PushDownFilter {
None => new_scan,
}
}
+ LogicalPlan::Extension(extension_plan) => {
+ let prevent_cols =
+ extension_plan.node.prevent_predicate_push_down_columns();
+
+ let predicates = utils::split_conjunction_owned(filter.predicate.clone());
+
+ let mut keep_predicates = vec![];
+ let mut push_predicates = vec![];
+ for expr in predicates {
+ let cols = expr.to_columns()?;
+ if cols.iter().any(|c| prevent_cols.contains(&c.name)) {
+ keep_predicates.push(expr);
+ } else {
+ push_predicates.push(expr);
+ }
+ }
+
+ let new_children = match conjunction(push_predicates) {
+ Some(predicate) => extension_plan
+ .node
+ .inputs()
+ .into_iter()
+ .map(|child| {
+ Ok(LogicalPlan::Filter(Filter::try_new(
+ predicate.clone(),
+ Arc::new(child.clone()),
+ )?))
+ })
+ .collect::<Result<Vec<_>>>()?,
+ None => extension_plan.node.inputs().into_iter().cloned().collect(),
+ };
+ // extension with new inputs.
+ let new_extension = child_plan.with_new_inputs(&new_children)?;
+
+ match conjunction(keep_predicates) {
+ Some(predicate) => LogicalPlan::Filter(Filter::try_new(
+ predicate,
+ Arc::new(new_extension),
+ )?),
+ None => new_extension,
+ }
+ }
_ => return Ok(None),
};
Ok(Some(new_plan))
@@ -774,12 +816,15 @@ mod tests {
use crate::OptimizerContext;
use arrow::datatypes::{DataType, Field, Schema, SchemaRef};
use async_trait::async_trait;
- use datafusion_common::DFSchema;
+ use datafusion_common::{DFSchema, DFSchemaRef};
use datafusion_expr::logical_plan::table_scan;
use datafusion_expr::{
and, col, in_list, in_subquery, lit, logical_plan::JoinType, or, sum, BinaryExpr,
- Expr, LogicalPlanBuilder, Operator, TableSource, TableType,
+ Expr, Extension, LogicalPlanBuilder, Operator, TableSource, TableType,
+ UserDefinedLogicalNode,
};
+ use std::any::Any;
+ use std::fmt::{Debug, Formatter};
use std::sync::Arc;
fn assert_optimized_plan_eq(plan: &LogicalPlan, expected: &str) -> Result<()> {
@@ -1029,6 +1074,131 @@ mod tests {
assert_optimized_plan_eq(&plan, expected)
}
+ #[derive(Debug)]
+ struct NoopPlan {
+ input: Vec<LogicalPlan>,
+ schema: DFSchemaRef,
+ }
+
+ impl UserDefinedLogicalNode for NoopPlan {
+ fn as_any(&self) -> &dyn Any {
+ self
+ }
+
+ fn inputs(&self) -> Vec<&LogicalPlan> {
+ self.input.iter().collect()
+ }
+
+ fn schema(&self) -> &DFSchemaRef {
+ &self.schema
+ }
+
+ fn expressions(&self) -> Vec<Expr> {
+ self.input
+ .iter()
+ .flat_map(|child| child.expressions())
+ .collect()
+ }
+
+ fn prevent_predicate_push_down_columns(&self) -> HashSet<String> {
+ HashSet::from_iter(vec!["c".to_string()])
+ }
+
+ fn fmt_for_explain(&self, f: &mut Formatter) -> std::fmt::Result {
+ write!(f, "NoopPlan")
+ }
+
+ fn from_template(
+ &self,
+ _exprs: &[Expr],
+ inputs: &[LogicalPlan],
+ ) -> Arc<dyn UserDefinedLogicalNode> {
+ Arc::new(Self {
+ input: inputs.to_vec(),
+ schema: self.schema.clone(),
+ })
+ }
+ }
+
+ #[test]
+ fn user_defined_plan() -> Result<()> {
+ let table_scan = test_table_scan()?;
+
+ let custom_plan = LogicalPlan::Extension(Extension {
+ node: Arc::new(NoopPlan {
+ input: vec![table_scan.clone()],
+ schema: table_scan.schema().clone(),
+ }),
+ });
+ let plan = LogicalPlanBuilder::from(custom_plan)
+ .filter(col("a").eq(lit(1i64)))?
+ .build()?;
+
+ // Push filter below NoopPlan
+ let expected = "\
+ NoopPlan\
+ \n Filter: test.a = Int64(1)\
+ \n TableScan: test";
+ assert_optimized_plan_eq(&plan, expected)?;
+
+ let custom_plan = LogicalPlan::Extension(Extension {
+ node: Arc::new(NoopPlan {
+ input: vec![table_scan.clone()],
+ schema: table_scan.schema().clone(),
+ }),
+ });
+ let plan = LogicalPlanBuilder::from(custom_plan)
+ .filter(col("a").eq(lit(1i64)).and(col("c").eq(lit(2i64))))?
+ .build()?;
+
+ // Push only predicate on `a` below NoopPlan
+ let expected = "\
+ Filter: test.c = Int64(2)\
+ \n NoopPlan\
+ \n Filter: test.a = Int64(1)\
+ \n TableScan: test";
+ assert_optimized_plan_eq(&plan, expected)?;
+
+ let custom_plan = LogicalPlan::Extension(Extension {
+ node: Arc::new(NoopPlan {
+ input: vec![table_scan.clone(), table_scan.clone()],
+ schema: table_scan.schema().clone(),
+ }),
+ });
+ let plan = LogicalPlanBuilder::from(custom_plan)
+ .filter(col("a").eq(lit(1i64)))?
+ .build()?;
+
+ // Push filter below NoopPlan for each child branch
+ let expected = "\
+ NoopPlan\
+ \n Filter: test.a = Int64(1)\
+ \n TableScan: test\
+ \n Filter: test.a = Int64(1)\
+ \n TableScan: test";
+ assert_optimized_plan_eq(&plan, expected)?;
+
+ let custom_plan = LogicalPlan::Extension(Extension {
+ node: Arc::new(NoopPlan {
+ input: vec![table_scan.clone(), table_scan.clone()],
+ schema: table_scan.schema().clone(),
+ }),
+ });
+ let plan = LogicalPlanBuilder::from(custom_plan)
+ .filter(col("a").eq(lit(1i64)).and(col("c").eq(lit(2i64))))?
+ .build()?;
+
+ // Push only predicate on `a` below NoopPlan
+ let expected = "\
+ Filter: test.c = Int64(2)\
+ \n NoopPlan\
+ \n Filter: test.a = Int64(1)\
+ \n TableScan: test\
+ \n Filter: test.a = Int64(1)\
+ \n TableScan: test";
+ assert_optimized_plan_eq(&plan, expected)
+ }
+
/// verifies that when two filters apply after an aggregation that only allows one to be pushed, one is pushed
/// and the other not.
#[test]