This is an automated email from the ASF dual-hosted git repository.

thinkharderdev pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/arrow-datafusion.git


The following commit(s) were added to refs/heads/main by this push:
     new c477fc0ca Fix filter pushdown for extension plans (#5425)
c477fc0ca is described below

commit c477fc0ca991cfb84d8d6f082879e969761d9125
Author: Dan Harris <[email protected]>
AuthorDate: Tue Feb 28 08:02:26 2023 -0500

    Fix filter pushdown for extension plans (#5425)
    
    * Fix filter pushdown for extension plans
    
    * Update datafusion/optimizer/src/push_down_filter.rs
    
    Co-authored-by: jakevin <[email protected]>
    
    ---------
    
    Co-authored-by: jakevin <[email protected]>
---
 datafusion/optimizer/src/push_down_filter.rs | 174 ++++++++++++++++++++++++++-
 1 file changed, 172 insertions(+), 2 deletions(-)

diff --git a/datafusion/optimizer/src/push_down_filter.rs b/datafusion/optimizer/src/push_down_filter.rs
index 6d8db2043..0d8da5573 100644
--- a/datafusion/optimizer/src/push_down_filter.rs
+++ b/datafusion/optimizer/src/push_down_filter.rs
@@ -735,6 +735,48 @@ impl OptimizerRule for PushDownFilter {
                     None => new_scan,
                 }
             }
+            LogicalPlan::Extension(extension_plan) => {
+                let prevent_cols =
+                    extension_plan.node.prevent_predicate_push_down_columns();
+
+                let predicates = utils::split_conjunction_owned(filter.predicate.clone());
+
+                let mut keep_predicates = vec![];
+                let mut push_predicates = vec![];
+                for expr in predicates {
+                    let cols = expr.to_columns()?;
+                    if cols.iter().any(|c| prevent_cols.contains(&c.name)) {
+                        keep_predicates.push(expr);
+                    } else {
+                        push_predicates.push(expr);
+                    }
+                }
+
+                let new_children = match conjunction(push_predicates) {
+                    Some(predicate) => extension_plan
+                        .node
+                        .inputs()
+                        .into_iter()
+                        .map(|child| {
+                            Ok(LogicalPlan::Filter(Filter::try_new(
+                                predicate.clone(),
+                                Arc::new(child.clone()),
+                            )?))
+                        })
+                        .collect::<Result<Vec<_>>>()?,
+                    None => extension_plan.node.inputs().into_iter().cloned().collect(),
+                };
+                // extension with new inputs.
+                let new_extension = child_plan.with_new_inputs(&new_children)?;
+
+                match conjunction(keep_predicates) {
+                    Some(predicate) => LogicalPlan::Filter(Filter::try_new(
+                        predicate,
+                        Arc::new(new_extension),
+                    )?),
+                    None => new_extension,
+                }
+            }
             _ => return Ok(None),
         };
         Ok(Some(new_plan))
@@ -774,12 +816,15 @@ mod tests {
     use crate::OptimizerContext;
     use arrow::datatypes::{DataType, Field, Schema, SchemaRef};
     use async_trait::async_trait;
-    use datafusion_common::DFSchema;
+    use datafusion_common::{DFSchema, DFSchemaRef};
     use datafusion_expr::logical_plan::table_scan;
     use datafusion_expr::{
        and, col, in_list, in_subquery, lit, logical_plan::JoinType, or, sum, BinaryExpr,
-        Expr, LogicalPlanBuilder, Operator, TableSource, TableType,
+        Expr, Extension, LogicalPlanBuilder, Operator, TableSource, TableType,
+        UserDefinedLogicalNode,
     };
+    use std::any::Any;
+    use std::fmt::{Debug, Formatter};
     use std::sync::Arc;
 
     fn assert_optimized_plan_eq(plan: &LogicalPlan, expected: &str) -> Result<()> {
@@ -1029,6 +1074,131 @@ mod tests {
         assert_optimized_plan_eq(&plan, expected)
     }
 
+    #[derive(Debug)]
+    struct NoopPlan {
+        input: Vec<LogicalPlan>,
+        schema: DFSchemaRef,
+    }
+
+    impl UserDefinedLogicalNode for NoopPlan {
+        fn as_any(&self) -> &dyn Any {
+            self
+        }
+
+        fn inputs(&self) -> Vec<&LogicalPlan> {
+            self.input.iter().collect()
+        }
+
+        fn schema(&self) -> &DFSchemaRef {
+            &self.schema
+        }
+
+        fn expressions(&self) -> Vec<Expr> {
+            self.input
+                .iter()
+                .flat_map(|child| child.expressions())
+                .collect()
+        }
+
+        fn prevent_predicate_push_down_columns(&self) -> HashSet<String> {
+            HashSet::from_iter(vec!["c".to_string()])
+        }
+
+        fn fmt_for_explain(&self, f: &mut Formatter) -> std::fmt::Result {
+            write!(f, "NoopPlan")
+        }
+
+        fn from_template(
+            &self,
+            _exprs: &[Expr],
+            inputs: &[LogicalPlan],
+        ) -> Arc<dyn UserDefinedLogicalNode> {
+            Arc::new(Self {
+                input: inputs.to_vec(),
+                schema: self.schema.clone(),
+            })
+        }
+    }
+
+    #[test]
+    fn user_defined_plan() -> Result<()> {
+        let table_scan = test_table_scan()?;
+
+        let custom_plan = LogicalPlan::Extension(Extension {
+            node: Arc::new(NoopPlan {
+                input: vec![table_scan.clone()],
+                schema: table_scan.schema().clone(),
+            }),
+        });
+        let plan = LogicalPlanBuilder::from(custom_plan)
+            .filter(col("a").eq(lit(1i64)))?
+            .build()?;
+
+        // Push filter below NoopPlan
+        let expected = "\
+            NoopPlan\
+            \n  Filter: test.a = Int64(1)\
+            \n    TableScan: test";
+        assert_optimized_plan_eq(&plan, expected)?;
+
+        let custom_plan = LogicalPlan::Extension(Extension {
+            node: Arc::new(NoopPlan {
+                input: vec![table_scan.clone()],
+                schema: table_scan.schema().clone(),
+            }),
+        });
+        let plan = LogicalPlanBuilder::from(custom_plan)
+            .filter(col("a").eq(lit(1i64)).and(col("c").eq(lit(2i64))))?
+            .build()?;
+
+        // Push only predicate on `a` below NoopPlan
+        let expected = "\
+            Filter: test.c = Int64(2)\
+            \n  NoopPlan\
+            \n    Filter: test.a = Int64(1)\
+            \n      TableScan: test";
+        assert_optimized_plan_eq(&plan, expected)?;
+
+        let custom_plan = LogicalPlan::Extension(Extension {
+            node: Arc::new(NoopPlan {
+                input: vec![table_scan.clone(), table_scan.clone()],
+                schema: table_scan.schema().clone(),
+            }),
+        });
+        let plan = LogicalPlanBuilder::from(custom_plan)
+            .filter(col("a").eq(lit(1i64)))?
+            .build()?;
+
+        // Push filter below NoopPlan for each child branch
+        let expected = "\
+            NoopPlan\
+            \n  Filter: test.a = Int64(1)\
+            \n    TableScan: test\
+            \n  Filter: test.a = Int64(1)\
+            \n    TableScan: test";
+        assert_optimized_plan_eq(&plan, expected)?;
+
+        let custom_plan = LogicalPlan::Extension(Extension {
+            node: Arc::new(NoopPlan {
+                input: vec![table_scan.clone(), table_scan.clone()],
+                schema: table_scan.schema().clone(),
+            }),
+        });
+        let plan = LogicalPlanBuilder::from(custom_plan)
+            .filter(col("a").eq(lit(1i64)).and(col("c").eq(lit(2i64))))?
+            .build()?;
+
+        // Push only predicate on `a` below NoopPlan
+        let expected = "\
+            Filter: test.c = Int64(2)\
+            \n  NoopPlan\
+            \n    Filter: test.a = Int64(1)\
+            \n      TableScan: test\
+            \n    Filter: test.a = Int64(1)\
+            \n      TableScan: test";
+        assert_optimized_plan_eq(&plan, expected)
+    }
+
     /// verifies that when two filters apply after an aggregation that only allows one to be pushed, one is pushed
     /// and the other not.
     #[test]

Reply via email to