alamb commented on code in PR #2647:
URL: https://github.com/apache/arrow-datafusion/pull/2647#discussion_r887261918


##########
datafusion/core/src/optimizer/filter_push_down.rs:
##########
@@ -178,13 +177,35 @@ fn lr_is_preserved(plan: &LogicalPlan) -> (bool, bool) {
     }
 }
 
+// For a given JOIN logical plan, determine whether each side of the join is 
preserved
+// in terms on join filtering.
+// Predicates from join filter can only be pushed to preserved join side.

Review Comment:
   👍 



##########
datafusion/core/src/optimizer/filter_push_down.rs:
##########
@@ -224,32 +244,67 @@ fn optimize_join(
     plan: &LogicalPlan,
     left: &LogicalPlan,
     right: &LogicalPlan,
+    on_filter: Vec<(Expr, HashSet<Column>)>,
 ) -> Result<LogicalPlan> {
+    // Get pushable predicates from current optimizer state
     let (left_preserved, right_preserved) = lr_is_preserved(plan);
-    let to_left = get_pushable_join_predicates(&state, left.schema(), 
left_preserved);
-    let to_right = get_pushable_join_predicates(&state, right.schema(), 
right_preserved);
-
+    let to_left =

Review Comment:
   I think this code is very nicely written and easy to follow 👍 



##########
datafusion/expr/src/utils.rs:
##########
@@ -378,19 +378,24 @@ pub fn from_plan(
             join_type,
             join_constraint,
             on,
-            filter,
             null_equals_null,
             ..
         }) => {
             let schema =
                 build_join_schema(inputs[0].schema(), inputs[1].schema(), 
join_type)?;
+            let filter_expr = if on.len() * 2 == expr.len() {

Review Comment:
   ```suggestion
               // Assume that the last expr, if any,
               // is the filter_expr (non equality predicate from ON clause)
               let filter_expr = if on.len() * 2 == expr.len() {
   ```



##########
datafusion/core/src/optimizer/filter_push_down.rs:
##########
@@ -1351,7 +1427,7 @@ mod tests {
         let right = LogicalPlanBuilder::from(right_table_scan)
             .project(vec![col("a"), col("b"), col("c")])?
             .build()?;
-        let filter = col("test.a")
+        let filter = col("test.c")

Review Comment:
   Why is this changed? Because `test.a` is a join predicate as well?



##########
datafusion/sql/src/planner.rs:
##########
@@ -589,98 +589,21 @@ impl<'a, S: ContextProvider> SqlToRel<'a, S> {
                             &[using_columns],
                         )
                     })
-                    .collect::<Result<Vec<_>>>()?;
+                    .collect::<Result<Vec<_>>>()?
+                    .into_iter()
+                    .reduce(Expr::and);
 
                 if left_keys.is_empty() {
                     // When we don't have join keys, use cross join
                     let join = 
LogicalPlanBuilder::from(left).cross_join(&right)?;
-                    normalized_filters
-                        .into_iter()
-                        .reduce(Expr::and)
+                    join_filter
                         .map(|filter| join.filter(filter))
                         .unwrap_or(Ok(join))?
                         .build()
-                } else if join_type == JoinType::Inner && 
!normalized_filters.is_empty() {

Review Comment:
   ❤️ 



##########
datafusion/core/src/optimizer/filter_push_down.rs:
##########
@@ -178,13 +177,35 @@ fn lr_is_preserved(plan: &LogicalPlan) -> (bool, bool) {
     }
 }
 
+// For a given JOIN logical plan, determine whether each side of the join is 
preserved
+// in terms on join filtering.
+// Predicates from join filter can only be pushed to preserved join side.
+fn on_lr_is_preserved(plan: &LogicalPlan) -> (bool, bool) {
+    match plan {
+        LogicalPlan::Join(Join { join_type, .. }) => match join_type {
+            JoinType::Inner => (true, true),
+            JoinType::Left => (false, true),
+            JoinType::Right => (true, false),
+            JoinType::Full => (false, false),
+            // Semi/Anti joins can not have join filter.
+            JoinType::Semi | JoinType::Anti => unreachable!(
+                "on_lr_is_preserved cannot be appplied to SEMI/ANTI-JOIN nodes"
+            ),
+        },
+        LogicalPlan::CrossJoin(_) => {
+            unreachable!("on_lr_is_preserved cannot be applied to CROSSJOIN 
nodes")
+        }
+        _ => unreachable!("on_lr_is_preserved only valid for JOIN nodes"),
+    }
+}
+
 // Determine which predicates in state can be pushed down to a given side of a 
join.
 // To determine this, we need to know the schema of the relevant join side and 
whether
 // or not the side's rows are preserved when joining. If the side is not 
preserved, we
 // do not push down anything. Otherwise we can push down predicates where all 
of the
 // relevant columns are contained on the relevant join side's schema.
 fn get_pushable_join_predicates<'a>(
-    state: &'a State,
+    filters: &'a [(Expr, HashSet<Column>)],

Review Comment:
   I think it would help to document what `filters` represents here, either 
in a doc string ("pairs of Exprs and a set of columns that represent ...") or 
maybe as a struct?
   
   ```rust
   struct PushFilters {
     /// expr represents an On clause predicate
     expr: Expr,
     /// cols represents columns that appear in `expr`
     cols: HashSet<Column>
   }
   ```



##########
datafusion/core/src/optimizer/filter_push_down.rs:
##########
@@ -1387,9 +1463,97 @@ mod tests {
         Ok(())
     }
 
+    /// join filter should be completely removed after pushdown
+    #[test]
+    fn join_filter_removed() -> Result<()> {
+        let table_scan = test_table_scan()?;
+        let left = LogicalPlanBuilder::from(table_scan)
+            .project(vec![col("a"), col("b"), col("c")])?
+            .build()?;
+        let right_table_scan = test_table_scan_with_name("test2")?;
+        let right = LogicalPlanBuilder::from(right_table_scan)
+            .project(vec![col("a"), col("b"), col("c")])?
+            .build()?;
+        let filter = col("test.b")
+            .gt(lit(1u32))
+            .and(col("test2.c").gt(lit(4u32)));
+        let plan = LogicalPlanBuilder::from(left)
+            .join(
+                &right,
+                JoinType::Inner,
+                (vec![Column::from_name("a")], vec![Column::from_name("a")]),
+                Some(filter),
+            )?
+            .build()?;
+
+        // not part of the test, just good to know:
+        assert_eq!(
+            format!("{:?}", plan),
+            "\
+            Inner Join: #test.a = #test2.a Filter: #test.b > UInt32(1) AND 
#test2.c > UInt32(4)\
+            \n  Projection: #test.a, #test.b, #test.c\
+            \n    TableScan: test projection=None\
+            \n  Projection: #test2.a, #test2.b, #test2.c\
+            \n    TableScan: test2 projection=None"
+        );
+
+        let expected = "\
+        Inner Join: #test.a = #test2.a\
+        \n  Projection: #test.a, #test.b, #test.c\
+        \n    Filter: #test.b > UInt32(1)\
+        \n      TableScan: test projection=None\
+        \n  Projection: #test2.a, #test2.b, #test2.c\
+        \n    Filter: #test2.c > UInt32(4)\
+        \n      TableScan: test2 projection=None";
+        assert_optimized_plan_eq(&plan, expected);
+        Ok(())
+    }
+
+    /// predicate on join key in filter expression should be pushed down to 
both inputs
+    #[test]
+    fn join_filter_on_common() -> Result<()> {
+        let table_scan = test_table_scan()?;
+        let left = LogicalPlanBuilder::from(table_scan)
+            .project(vec![col("a")])?
+            .build()?;
+        let right_table_scan = test_table_scan_with_name("test2")?;
+        let right = LogicalPlanBuilder::from(right_table_scan)
+            .project(vec![col("a")])?
+            .build()?;
+        let filter = col("test.a").gt(lit(1u32));
+        let plan = LogicalPlanBuilder::from(left)
+            .join(
+                &right,
+                JoinType::Inner,
+                (vec![Column::from_name("a")], vec![Column::from_name("a")]),
+                Some(filter),
+            )?
+            .build()?;
+
+        // not part of the test, just good to know:
+        assert_eq!(
+            format!("{:?}", plan),
+            "\
+            Inner Join: #test.a = #test2.a Filter: #test.a > UInt32(1)\
+            \n  Projection: #test.a\
+            \n    TableScan: test projection=None\
+            \n  Projection: #test2.a\
+            \n    TableScan: test2 projection=None"
+        );
+
+        let expected = "\
+        Inner Join: #test.a = #test2.a\

Review Comment:
   I recommend using different columns here -- like `test.a = test2.b` -- so 
that you can validate that the correct column was pushed to each side.



##########
datafusion/sql/src/planner.rs:
##########
@@ -3837,10 +3759,9 @@ mod tests {
             LEFT JOIN orders \
             ON id = customer_id AND order_id > 1 AND age < 30";
         let expected = "Projection: #person.id, #orders.order_id\
-        \n  Left Join: #person.id = #orders.customer_id Filter: #person.age < 
Int64(30)\
+        \n  Left Join: #person.id = #orders.customer_id Filter: 
#orders.order_id > Int64(1) AND #person.age < Int64(30)\
         \n    TableScan: person projection=None\
-        \n    Filter: #orders.order_id > Int64(1)\
-        \n      TableScan: orders projection=None";
+        \n    TableScan: orders projection=None";

Review Comment:
   So to be clear, when run normally the pushdown will still happen, but the 
predicates are no longer pushed down by the planner; rather, they are pushed 
down at a later stage.
   
   👍 



##########
datafusion/expr/src/logical_plan/plan.rs:
##########
@@ -227,9 +227,15 @@ impl LogicalPlan {
                 aggr_expr,
                 ..
             }) => group_expr.iter().chain(aggr_expr.iter()).cloned().collect(),
-            LogicalPlan::Join(Join { on, .. }) => on
+            LogicalPlan::Join(Join { on, filter, .. }) => on
                 .iter()
                 .flat_map(|(l, r)| vec![Expr::Column(l.clone()), 
Expr::Column(r.clone())])
+                .chain(

Review Comment:
   This is a good catch



##########
datafusion/core/src/optimizer/filter_push_down.rs:
##########
@@ -224,32 +244,67 @@ fn optimize_join(
     plan: &LogicalPlan,
     left: &LogicalPlan,
     right: &LogicalPlan,
+    on_filter: Vec<(Expr, HashSet<Column>)>,
 ) -> Result<LogicalPlan> {
+    // Get pushable predicates from current optimizer state
     let (left_preserved, right_preserved) = lr_is_preserved(plan);
-    let to_left = get_pushable_join_predicates(&state, left.schema(), 
left_preserved);
-    let to_right = get_pushable_join_predicates(&state, right.schema(), 
right_preserved);
-
+    let to_left =
+        get_pushable_join_predicates(&state.filters, left.schema(), 
left_preserved);
+    let to_right =
+        get_pushable_join_predicates(&state.filters, right.schema(), 
right_preserved);
     let to_keep: Predicates = state
         .filters
         .iter()
-        .filter(|(expr, _)| {
-            let pushed_to_left = to_left.0.contains(&expr);
-            let pushed_to_right = to_right.0.contains(&expr);
-            !pushed_to_left && !pushed_to_right
-        })
+        .filter(|(e, _)| !to_left.0.contains(&e) && !to_right.0.contains(&e))
         .map(|(a, b)| (a, b))
         .unzip();
 
-    let mut left_state = state.clone();
-    left_state.filters = keep_filters(&left_state.filters, &to_left);
+    // Get pushable predicates from join filter
+    let (on_to_left, on_to_right, on_to_keep) = if on_filter.is_empty() {
+        ((vec![], vec![]), (vec![], vec![]), vec![])
+    } else {
+        let (on_left_preserved, on_right_preserved) = on_lr_is_preserved(plan);
+        let on_to_left =
+            get_pushable_join_predicates(&on_filter, left.schema(), 
on_left_preserved);
+        let on_to_right =
+            get_pushable_join_predicates(&on_filter, right.schema(), 
on_right_preserved);
+        let on_to_keep = on_filter
+            .iter()
+            .filter(|(e, _)| !on_to_left.0.contains(&e) && 
!on_to_right.0.contains(&e))
+            .map(|(a, _)| a.clone())
+            .collect::<Vec<_>>();
+
+        (on_to_left, on_to_right, on_to_keep)
+    };
+
+    // Build new filter states using pushable predicates
+    // from current optimizer states and from ON clause.
+    // Then recursively call optimization for both join inputs
+    let mut left_state = State { filters: vec![] };
+    left_state.append_predicates(to_left);
+    left_state.append_predicates(on_to_left);
     let left = optimize(left, left_state)?;
 
-    let mut right_state = state.clone();
-    right_state.filters = keep_filters(&right_state.filters, &to_right);
+    let mut right_state = State { filters: vec![] };
+    right_state.append_predicates(to_right);
+    right_state.append_predicates(on_to_right);
     let right = optimize(right, right_state)?;
 
     // create a new Join with the new `left` and `right`
     let expr = plan.expressions();
+    let expr = if !on_filter.is_empty() && on_to_keep.is_empty() {

Review Comment:
   I don't really follow how we know the last element here in `expr` is the 
`on` expression -- doesn't that implicitly depend on the order of expressions 
returned from `LogicalPlan::expressions()`?
   
   I wonder if we can make it more explicit somehow



-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

To unsubscribe, e-mail: [email protected]

For queries about this service, please contact Infrastructure at:
[email protected]

Reply via email to