This is an automated email from the ASF dual-hosted git repository.

dheres pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/arrow-datafusion.git


The following commit(s) were added to refs/heads/master by this push:
     new e02376ddc Pushdown single column predicates from ON join clauses (#3578)
e02376ddc is described below

commit e02376ddc431a818e1f19a5bb16fe45307a512e8
Author: AssHero <[email protected]>
AuthorDate: Sat Oct 15 21:46:53 2022 +0800

    Pushdown single column predicates from ON join clauses (#3578)
    
    * extract OR clause for join
    
    * add more comments
    
    * add some comments
    
    * Update TPCH plans
    
    * Update test plan
    
    Co-authored-by: Andrew Lamb <[email protected]>
---
 benchmarks/expected-plans/q19.txt            |   6 +-
 benchmarks/expected-plans/q7.txt             |   8 +-
 benchmarks/src/bin/tpch.rs                   |   5 +-
 datafusion/core/tests/sql/joins.rs           |   3 +-
 datafusion/core/tests/sql/predicates.rs      |   8 +
 datafusion/optimizer/src/filter_push_down.rs | 223 ++++++++++++++++++++++++++-
 6 files changed, 243 insertions(+), 10 deletions(-)

diff --git a/benchmarks/expected-plans/q19.txt 
b/benchmarks/expected-plans/q19.txt
index 902893ea9..cbf4f08b3 100644
--- a/benchmarks/expected-plans/q19.txt
+++ b/benchmarks/expected-plans/q19.txt
@@ -3,7 +3,7 @@ Projection: SUM(lineitem.l_extendedprice * Int64(1) - 
lineitem.l_discount) AS re
     Projection: lineitem.l_shipinstruct = Utf8("DELIVER IN PERSON") AS 
lineitem.l_shipinstruct = Utf8("DELIVER IN PERSON")Utf8("DELIVER IN 
PERSON")lineitem.l_shipinstruct, lineitem.l_shipmode IN ([Utf8("AIR"), 
Utf8("AIR REG")]) AS lineitem.l_shipmode IN ([Utf8("AIR"), Utf8("AIR 
REG")])Utf8("AIR REG")Utf8("AIR")lineitem.l_shipmode, part.p_size >= Int32(1) 
AS part.p_size >= Int32(1)Int32(1)part.p_size, lineitem.l_quantity, 
lineitem.l_extendedprice, lineitem.l_discount, part.p_brand, part.p [...]
       Filter: part.p_brand = Utf8("Brand#12") AND part.p_container IN 
([Utf8("SM CASE"), Utf8("SM BOX"), Utf8("SM PACK"), Utf8("SM PKG")]) AND 
lineitem.l_quantity >= Decimal128(Some(100),15,2) AND lineitem.l_quantity <= 
Decimal128(Some(1100),15,2) AND part.p_size <= Int32(5) OR part.p_brand = 
Utf8("Brand#23") AND part.p_container IN ([Utf8("MED BAG"), Utf8("MED BOX"), 
Utf8("MED PKG"), Utf8("MED PACK")]) AND lineitem.l_quantity >= 
Decimal128(Some(1000),15,2) AND lineitem.l_quantity <= Dec [...]
         Inner Join: lineitem.l_partkey = part.p_partkey
-          Filter: lineitem.l_shipmode IN ([Utf8("AIR"), Utf8("AIR REG")]) AND 
lineitem.l_shipinstruct = Utf8("DELIVER IN PERSON")
+          Filter: lineitem.l_shipmode IN ([Utf8("AIR"), Utf8("AIR REG")]) AND 
lineitem.l_shipinstruct = Utf8("DELIVER IN PERSON") AND lineitem.l_quantity >= 
Decimal128(Some(100),15,2) AND lineitem.l_quantity <= 
Decimal128(Some(1100),15,2) OR lineitem.l_quantity >= 
Decimal128(Some(1000),15,2) AND lineitem.l_quantity <= 
Decimal128(Some(2000),15,2) OR lineitem.l_quantity >= 
Decimal128(Some(2000),15,2) AND lineitem.l_quantity <= 
Decimal128(Some(3000),15,2)
             TableScan: lineitem projection=[l_partkey, l_quantity, 
l_extendedprice, l_discount, l_shipinstruct, l_shipmode]
-          Filter: part.p_size >= Int32(1)
-            TableScan: part projection=[p_partkey, p_brand, p_size, 
p_container]
\ No newline at end of file
+          Filter: part.p_size >= Int32(1) AND part.p_brand = Utf8("Brand#12") 
AND part.p_container IN ([Utf8("SM CASE"), Utf8("SM BOX"), Utf8("SM PACK"), 
Utf8("SM PKG")]) AND part.p_size <= Int32(5) OR part.p_brand = Utf8("Brand#23") 
AND part.p_container IN ([Utf8("MED BAG"), Utf8("MED BOX"), Utf8("MED PKG"), 
Utf8("MED PACK")]) AND part.p_size <= Int32(10) OR part.p_brand = 
Utf8("Brand#34") AND part.p_container IN ([Utf8("LG CASE"), Utf8("LG BOX"), 
Utf8("LG PACK"), Utf8("LG PKG")]) AND p [...]
+            TableScan: part projection=[p_partkey, p_brand, p_size, 
p_container]
diff --git a/benchmarks/expected-plans/q7.txt b/benchmarks/expected-plans/q7.txt
index 4a2866a42..a1d1806f9 100644
--- a/benchmarks/expected-plans/q7.txt
+++ b/benchmarks/expected-plans/q7.txt
@@ -14,7 +14,9 @@ Sort: shipping.supp_nation ASC NULLS LAST, 
shipping.cust_nation ASC NULLS LAST,
                         TableScan: lineitem projection=[l_orderkey, l_suppkey, 
l_extendedprice, l_discount, l_shipdate]
                     TableScan: orders projection=[o_orderkey, o_custkey]
                   TableScan: customer projection=[c_custkey, c_nationkey]
-                SubqueryAlias: n1
+                Filter: n1.n_name = Utf8("FRANCE") OR n1.n_name = 
Utf8("GERMANY")
+                  SubqueryAlias: n1
+                    TableScan: nation projection=[n_nationkey, n_name]
+              Filter: n2.n_name = Utf8("GERMANY") OR n2.n_name = Utf8("FRANCE")
+                SubqueryAlias: n2
                   TableScan: nation projection=[n_nationkey, n_name]
-              SubqueryAlias: n2
-                TableScan: nation projection=[n_nationkey, n_name]
\ No newline at end of file
diff --git a/benchmarks/src/bin/tpch.rs b/benchmarks/src/bin/tpch.rs
index 7930cb73c..d1aa2c8eb 100644
--- a/benchmarks/src/bin/tpch.rs
+++ b/benchmarks/src/bin/tpch.rs
@@ -792,7 +792,10 @@ mod tests {
         for path in &possibilities {
             let path = Path::new(&path);
             if let Ok(expected) = read_text_file(path) {
-                assert_eq!(expected, actual);
+                assert_eq!(expected, actual,
+                           // generate output that is easier to 
copy/paste/update
+                           "\n\nMismatch of expected content in: 
{:?}\nExpected:\n\n{}\n\nActual:\n\n{}\n\n",
+                           path, expected, actual);
                 found = true;
                 break;
             }
diff --git a/datafusion/core/tests/sql/joins.rs 
b/datafusion/core/tests/sql/joins.rs
index b5b59b1b6..2ff4947b3 100644
--- a/datafusion/core/tests/sql/joins.rs
+++ b/datafusion/core/tests/sql/joins.rs
@@ -1474,7 +1474,8 @@ async fn reduce_left_join_2() -> Result<()> {
         "    Filter: CAST(t2.t2_int AS Int64) < Int64(10) OR CAST(t1.t1_int AS 
Int64) > Int64(2) AND t2.t2_name != Utf8(\"w\") [t1_id:UInt32;N, 
t1_name:Utf8;N, t1_int:UInt32;N, t2_id:UInt32;N, t2_name:Utf8;N, 
t2_int:UInt32;N]",
         "      Inner Join: t1.t1_id = t2.t2_id [t1_id:UInt32;N, 
t1_name:Utf8;N, t1_int:UInt32;N, t2_id:UInt32;N, t2_name:Utf8;N, 
t2_int:UInt32;N]",
         "        TableScan: t1 projection=[t1_id, t1_name, t1_int] 
[t1_id:UInt32;N, t1_name:Utf8;N, t1_int:UInt32;N]",
-        "        TableScan: t2 projection=[t2_id, t2_name, t2_int] 
[t2_id:UInt32;N, t2_name:Utf8;N, t2_int:UInt32;N]",
+        "        Filter: CAST(t2.t2_int AS Int64) < Int64(10) OR t2.t2_name != 
Utf8(\"w\") [t2_id:UInt32;N, t2_name:Utf8;N, t2_int:UInt32;N]",
+        "          TableScan: t2 projection=[t2_id, t2_name, t2_int] 
[t2_id:UInt32;N, t2_name:Utf8;N, t2_int:UInt32;N]",
     ];
     let formatted = plan.display_indent_schema().to_string();
     let actual: Vec<&str> = formatted.trim().lines().collect();
diff --git a/datafusion/core/tests/sql/predicates.rs 
b/datafusion/core/tests/sql/predicates.rs
index bb4391c4f..07e016a27 100644
--- a/datafusion/core/tests/sql/predicates.rs
+++ b/datafusion/core/tests/sql/predicates.rs
@@ -468,6 +468,14 @@ async fn multiple_or_predicates() -> Result<()> {
     // factored out and appear only once in the following plan
     let expected = vec![
         "Explain [plan_type:Utf8, plan:Utf8]",
+        "  Projection: #lineitem.l_partkey [l_partkey:Int64]",
+        "    Projection: #part.p_size >= Int32(1) AS #part.p_size >= 
Int32(1)Int32(1)#part.p_size, #lineitem.l_partkey, #lineitem.l_quantity, 
#part.p_brand, #part.p_size [#part.p_size >= 
Int32(1)Int32(1)#part.p_size:Boolean;N, l_partkey:Int64, 
l_quantity:Decimal128(15, 2), p_brand:Utf8, p_size:Int32]",
+        "      Filter: #part.p_brand = Utf8(\"Brand#12\") AND 
#lineitem.l_quantity >= Decimal128(Some(100),15,2) AND #lineitem.l_quantity <= 
Decimal128(Some(1100),15,2) AND #part.p_size <= Int32(5) OR #part.p_brand = 
Utf8(\"Brand#23\") AND #lineitem.l_quantity >= Decimal128(Some(1000),15,2) AND 
#lineitem.l_quantity <= Decimal128(Some(2000),15,2) AND #part.p_size <= 
Int32(10) OR #part.p_brand = Utf8(\"Brand#34\") AND #lineitem.l_quantity >= 
Decimal128(Some(2000),15,2) AND #lineitem.l_quan [...]
+        "        Inner Join: #lineitem.l_partkey = #part.p_partkey 
[l_partkey:Int64, l_quantity:Decimal128(15, 2), p_partkey:Int64, p_brand:Utf8, 
p_size:Int32]",
+        "          Filter: #lineitem.l_quantity >= Decimal128(Some(100),15,2) 
AND #lineitem.l_quantity <= Decimal128(Some(1100),15,2) OR #lineitem.l_quantity 
>= Decimal128(Some(1000),15,2) AND #lineitem.l_quantity <= 
Decimal128(Some(2000),15,2) OR #lineitem.l_quantity >= 
Decimal128(Some(2000),15,2) AND #lineitem.l_quantity <= 
Decimal128(Some(3000),15,2) [l_partkey:Int64, l_quantity:Decimal128(15, 2)]",
+        "            TableScan: lineitem projection=[l_partkey, l_quantity], 
partial_filters=[#lineitem.l_quantity >= Decimal128(Some(100),15,2) AND 
#lineitem.l_quantity <= Decimal128(Some(1100),15,2) OR #lineitem.l_quantity >= 
Decimal128(Some(1000),15,2) AND #lineitem.l_quantity <= 
Decimal128(Some(2000),15,2) OR #lineitem.l_quantity >= 
Decimal128(Some(2000),15,2) AND #lineitem.l_quantity <= 
Decimal128(Some(3000),15,2)] [l_partkey:Int64, l_quantity:Decimal128(15, 2)]",
+        "          Filter: #part.p_size >= Int32(1) AND #part.p_brand = 
Utf8(\"Brand#12\") AND #part.p_size <= Int32(5) OR #part.p_brand = 
Utf8(\"Brand#23\") AND #part.p_size <= Int32(10) OR #part.p_brand = 
Utf8(\"Brand#34\") AND #part.p_size <= Int32(15) [p_partkey:Int64, 
p_brand:Utf8, p_size:Int32]",
+        "            TableScan: part projection=[p_partkey, p_brand, p_size], 
partial_filters=[#part.p_size >= Int32(1), #part.p_brand = Utf8(\"Brand#12\") 
AND #part.p_size <= Int32(5) OR #part.p_brand = Utf8(\"Brand#23\") AND 
#part.p_size <= Int32(10) OR #part.p_brand = Utf8(\"Brand#34\") AND 
#part.p_size <= Int32(15)] [p_partkey:Int64, p_brand:Utf8, p_size:Int32]",
         "  Projection: lineitem.l_partkey [l_partkey:Int64]",
         "    Filter: part.p_brand = Utf8(\"Brand#12\") AND lineitem.l_quantity 
>= Decimal128(Some(100),15,2) AND lineitem.l_quantity <= 
Decimal128(Some(1100),15,2) AND CAST(part.p_size AS Int64) BETWEEN Int64(1) AND 
Int64(5) OR part.p_brand = Utf8(\"Brand#23\") AND lineitem.l_quantity >= 
Decimal128(Some(1000),15,2) AND lineitem.l_quantity <= 
Decimal128(Some(2000),15,2) AND CAST(part.p_size AS Int64) BETWEEN Int64(1) AND 
Int64(10) OR part.p_brand = Utf8(\"Brand#34\") AND lineitem.l_quanti [...]
         "      Inner Join: lineitem.l_partkey = part.p_partkey 
[l_partkey:Int64, l_quantity:Decimal128(15, 2), p_partkey:Int64, p_brand:Utf8, 
p_size:Int32]",
diff --git a/datafusion/optimizer/src/filter_push_down.rs 
b/datafusion/optimizer/src/filter_push_down.rs
index 4d720eb22..08ba71cda 100644
--- a/datafusion/optimizer/src/filter_push_down.rs
+++ b/datafusion/optimizer/src/filter_push_down.rs
@@ -17,14 +17,15 @@
 use crate::{utils, OptimizerConfig, OptimizerRule};
 use datafusion_common::{Column, DFSchema, DataFusionError, Result};
 use datafusion_expr::{
-    col,
+    and, col,
     expr_rewriter::{replace_col, ExprRewritable, ExprRewriter},
     logical_plan::{
         Aggregate, CrossJoin, Join, JoinType, Limit, LogicalPlan, Projection, 
TableScan,
         Union,
     },
+    or,
     utils::{expr_to_columns, exprlist_to_columns, from_plan},
-    Expr, TableProviderFilterPushDown,
+    Expr, Operator, TableProviderFilterPushDown,
 };
 use std::collections::{HashMap, HashSet};
 use std::iter::once;
@@ -247,6 +248,156 @@ fn get_pushable_join_predicates<'a>(
         .unzip()
 }
 
+/// Derives pushable predicates for one join input from the `OR` clauses in `filters`.
+///
+/// For every top-level `OR` expression in `filters`, this tries to extract, from each
+/// branch of the `OR`, a predicate that is implied by that branch and references only
+/// columns of `schema` (see `extract_or_clause`). If both branches yield something,
+/// the two extracted predicates are combined with `OR`. The result is implied by the
+/// original expression and can be pushed down to the join input described by
+/// `schema`. The original `OR` expression is not consumed; it stays wherever the
+/// caller keeps it.
+///
+/// Example:
+///
+/// ```text
+/// Filter: (a = c AND a < 20) OR (b = d AND b > 10)
+///   Join/CrossJoin:
+///     TableScan: projection=[a, b]
+///     TableScan: projection=[c, d]
+/// ```
+///
+/// is optimized to
+///
+/// ```text
+/// Filter: (a = c AND a < 20) OR (b = d AND b > 10)
+///   Join/CrossJoin:
+///     Filter: (a < 20) OR (b > 10)
+///       TableScan: projection=[a, b]
+///     TableScan: projection=[c, d]
+/// ```
+///
+/// In general, a predicate of the form `(A AND B) OR (C AND D)` results in:
+///
+/// * `((A AND B) OR (C AND D)) AND (A OR C)` when only `A` and `C` reference the
+///   target input;
+/// * `((A AND B) OR (C AND D)) AND ((A AND B) OR C)` when `A`, `B` and `C` do;
+/// * no change when nothing can be extracted from one of the `OR` branches.
+///
+/// * `filters` - candidate predicates; expressions that are not a top-level `OR`
+///   are ignored.
+/// * `schema` - schema of the join input the new predicates are meant for; both
+///   qualified and unqualified references to its fields are accepted.
+/// * `preserved` - whether predicates may be pushed to this input at all; when
+///   `false`, nothing is returned.
+///
+/// Returns the new predicates together with the set of columns each one references.
+/// The two vectors have the same length and are index-aligned.
+///
+/// NOTE(review): predicates taken from a join's ON clause must be gated by the
+/// ON-clause preservation rule (for a LEFT join only the right input may receive
+/// them), not by the WHERE-clause rule. Passing the WHERE-clause flag for ON-clause
+/// predicates would filter out rows of the preserved side. Confirm that callers use
+/// the right flag for each source of predicates.
+fn extract_or_clauses_for_join(
+    filters: &[&Expr],
+    schema: &DFSchema,
+    preserved: bool,
+) -> (Vec<Expr>, Vec<HashSet<Column>>) {
+    if !preserved {
+        return (vec![], vec![]);
+    }
+
+    // Predicates may refer to a field either qualified or unqualified, so accept both.
+    let schema_columns = schema
+        .fields()
+        .iter()
+        .flat_map(|f| [f.qualified_column(), f.unqualified_column()])
+        .collect::<HashSet<_>>();
+
+    let mut exprs = vec![];
+    let mut expr_columns = vec![];
+    for expr in filters.iter() {
+        if let Expr::BinaryExpr {
+            left,
+            op: Operator::Or,
+            right,
+        } = expr
+        {
+            // Only when *both* branches imply something about this input is the
+            // combined `OR` a valid implied predicate for it.
+            let extracted = extract_or_clause(left.as_ref(), &schema_columns)
+                .zip(extract_or_clause(right.as_ref(), &schema_columns));
+
+            if let Some((left_expr, right_expr)) = extracted {
+                let predicate = or(left_expr, right_expr);
+                let mut columns: HashSet<Column> = HashSet::new();
+                // The derived predicate is purely an optimization: if its columns
+                // cannot be collected, skip it instead of panicking.
+                if expr_to_columns(&predicate, &mut columns).is_ok() {
+                    exprs.push(predicate);
+                    expr_columns.push(columns);
+                }
+            }
+        }
+    }
+
+    // Newly formed OR clauses and their column references.
+    (exprs, expr_columns)
+}
+
+/// Extracts from `expr` a predicate that is implied by `expr` and references only
+/// columns contained in `schema_columns`.
+///
+/// * `a AND b`: extract from both operands.
+///   * If both succeed, the result is `a' AND b'`.
+///   * If only one succeeds, that one is returned on its own. Dropping a conjunct
+///     only weakens the predicate, so it is still implied by `expr`.
+/// * `a OR b`: extract from both operands. The result is `a' OR b'` only if both
+///   succeed; otherwise it is `None`, because one unconstrained branch means nothing
+///   is implied.
+/// * Any other expression: returned unchanged (cloned) if every column it references
+///   is in `schema_columns`, otherwise `None`.
+///
+/// `None` is always a safe answer: it only means that no extra predicate is derived.
+fn extract_or_clause(expr: &Expr, schema_columns: &HashSet<Column>) -> Option<Expr> {
+    match expr {
+        Expr::BinaryExpr {
+            left,
+            op: Operator::Or,
+            right,
+        } => {
+            let left = extract_or_clause(left, schema_columns)?;
+            let right = extract_or_clause(right, schema_columns)?;
+            Some(or(left, right))
+        }
+        Expr::BinaryExpr {
+            left,
+            op: Operator::And,
+            right,
+        } => match (
+            extract_or_clause(left, schema_columns),
+            extract_or_clause(right, schema_columns),
+        ) {
+            (Some(left), Some(right)) => Some(and(left, right)),
+            (Some(only), None) | (None, Some(only)) => Some(only),
+            (None, None) => None,
+        },
+        _ => {
+            let mut columns: HashSet<Column> = HashSet::new();
+            // If the referenced columns cannot be determined, be conservative and
+            // extract nothing instead of panicking.
+            expr_to_columns(expr, &mut columns).ok()?;
+            if columns.is_subset(schema_columns) {
+                Some(expr.clone())
+            } else {
+                None
+            }
+        }
+    }
+}
+
 fn optimize_join(
     mut state: State,
     plan: &LogicalPlan,
@@ -285,17 +436,54 @@ fn optimize_join(
         (on_to_left, on_to_right, on_to_keep)
     };
 
+    // Extract from OR clause, generate new predicates for both side of join 
if possible.
+    // We only track the unpushable predicates above.
+    let or_to_left =
+        extract_or_clauses_for_join(&to_keep.0, left.schema(), left_preserved);
+    let or_to_right =
+        extract_or_clauses_for_join(&to_keep.0, right.schema(), 
right_preserved);
+    let on_or_to_left = extract_or_clauses_for_join(
+        &on_to_keep.iter().collect::<Vec<_>>(),
+        left.schema(),
+        left_preserved,
+    );
+    let on_or_to_right = extract_or_clauses_for_join(
+        &on_to_keep.iter().collect::<Vec<_>>(),
+        right.schema(),
+        right_preserved,
+    );
+
     // Build new filter states using pushable predicates
     // from current optimizer states and from ON clause.
     // Then recursively call optimization for both join inputs
     let mut left_state = State { filters: vec![] };
     left_state.append_predicates(to_left);
     left_state.append_predicates(on_to_left);
+    or_to_left
+        .0
+        .into_iter()
+        .zip(or_to_left.1)
+        .for_each(|(expr, cols)| left_state.filters.push((expr, cols)));
+    on_or_to_left
+        .0
+        .into_iter()
+        .zip(on_or_to_left.1)
+        .for_each(|(expr, cols)| left_state.filters.push((expr, cols)));
     let left = optimize(left, left_state)?;
 
     let mut right_state = State { filters: vec![] };
     right_state.append_predicates(to_right);
     right_state.append_predicates(on_to_right);
+    or_to_right
+        .0
+        .into_iter()
+        .zip(or_to_right.1)
+        .for_each(|(expr, cols)| right_state.filters.push((expr, cols)));
+    on_or_to_right
+        .0
+        .into_iter()
+        .zip(on_or_to_right.1)
+        .for_each(|(expr, cols)| right_state.filters.push((expr, cols)));
     let right = optimize(right, right_state)?;
 
     // Create a new Join with the new `left` and `right`
@@ -2134,4 +2322,35 @@ mod tests {
 
         Ok(())
     }
+
+    #[test]
+    fn test_crossjoin_with_or_clause() -> Result<()> {
+        // Roughly:
+        //   SELECT * FROM test, (SELECT a AS d, a AS e FROM test1)
+        //   WHERE (test.a = d AND test.b > 1) OR (test.b = e AND test.c < 10)
+        // Only `test.b > 1 OR test.c < 10` can be derived for the left input, and
+        // nothing can be derived for the right one.
+        let lhs = LogicalPlanBuilder::from(test_table_scan()?)
+            .project(vec![col("a"), col("b"), col("c")])?
+            .build()?;
+        let rhs = LogicalPlanBuilder::from(test_table_scan_with_name("test1")?)
+            .project(vec![col("a").alias("d"), col("a").alias("e")])?
+            .build()?;
+
+        let first_branch = and(col("a").eq(col("d")), col("b").gt(lit(1u32)));
+        let second_branch = and(col("b").eq(col("e")), col("c").lt(lit(10u32)));
+
+        let plan = LogicalPlanBuilder::from(lhs)
+            .cross_join(&rhs)?
+            .filter(or(first_branch, second_branch))?
+            .build()?;
+
+        let expected = "Filter: test.a = d AND test.b > UInt32(1) OR test.b = e AND test.c < UInt32(10)\
+                        \n  CrossJoin:\
+                        \n    Projection: test.a, test.b, test.c\
+                        \n      Filter: test.b > UInt32(1) OR test.c < UInt32(10)\
+                        \n        TableScan: test\
+                        \n    Projection: test1.a AS d, test1.a AS e\
+                        \n      TableScan: test1";
+        assert_optimized_plan_eq(&plan, expected);
+        Ok(())
+    }
 }

Reply via email to