This is an automated email from the ASF dual-hosted git repository.

comphead pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/datafusion.git


The following commit(s) were added to refs/heads/main by this push:
     new 1eff714ef8 Remove some Expr clones in `EliminateCrossJoin`(3%-5% faster planning) (#10430)
1eff714ef8 is described below

commit 1eff714ef8356dc305047386ba250b62bed6a795
Author: Andrew Lamb <[email protected]>
AuthorDate: Sat May 11 12:36:29 2024 -0400

    Remove some Expr clones in `EliminateCrossJoin`(3%-5% faster planning) (#10430)
    
    * Remove some Expr clones in `EliminateCrossJoin`
    
    * Apply suggestions from code review
    
    Co-authored-by: comphead <[email protected]>
    
    * fix
    
    ---------
    
    Co-authored-by: comphead <[email protected]>
---
 datafusion/optimizer/src/eliminate_cross_join.rs | 123 +++++-------
 datafusion/optimizer/src/join_key_set.rs         | 240 +++++++++++++++++++++++
 datafusion/optimizer/src/lib.rs                  |   1 +
 3 files changed, 291 insertions(+), 73 deletions(-)

diff --git a/datafusion/optimizer/src/eliminate_cross_join.rs b/datafusion/optimizer/src/eliminate_cross_join.rs
index a807ee5ff2..923be75748 100644
--- a/datafusion/optimizer/src/eliminate_cross_join.rs
+++ b/datafusion/optimizer/src/eliminate_cross_join.rs
@@ -16,11 +16,11 @@
 // under the License.
 
 //! [`EliminateCrossJoin`] converts `CROSS JOIN` to `INNER JOIN` if join predicates are available.
-use std::collections::HashSet;
 use std::sync::Arc;
 
 use crate::{utils, OptimizerConfig, OptimizerRule};
 
+use crate::join_key_set::JoinKeySet;
 use datafusion_common::{plan_err, Result};
 use datafusion_expr::expr::{BinaryExpr, Expr};
 use datafusion_expr::logical_plan::{
@@ -55,7 +55,7 @@ impl OptimizerRule for EliminateCrossJoin {
         plan: &LogicalPlan,
         config: &dyn OptimizerConfig,
     ) -> Result<Option<LogicalPlan>> {
-        let mut possible_join_keys: Vec<(Expr, Expr)> = vec![];
+        let mut possible_join_keys = JoinKeySet::new();
         let mut all_inputs: Vec<LogicalPlan> = vec![];
         let parent_predicate = match plan {
             LogicalPlan::Filter(filter) => {
@@ -76,7 +76,7 @@ impl OptimizerRule for EliminateCrossJoin {
                         extract_possible_join_keys(
                             &filter.predicate,
                             &mut possible_join_keys,
-                        )?;
+                        );
                         Some(&filter.predicate)
                     }
                     _ => {
@@ -101,7 +101,7 @@ impl OptimizerRule for EliminateCrossJoin {
         };
 
         // Join keys are handled locally:
-        let mut all_join_keys = HashSet::<(Expr, Expr)>::new();
+        let mut all_join_keys = JoinKeySet::new();
         let mut left = all_inputs.remove(0);
         while !all_inputs.is_empty() {
             left = find_inner_join(
@@ -131,7 +131,7 @@ impl OptimizerRule for EliminateCrossJoin {
                 .map(|f| Some(LogicalPlan::Filter(f)))
         } else {
             // Remove join expressions from filter:
-            match remove_join_expressions(predicate, &all_join_keys)? {
+            match remove_join_expressions(predicate.clone(), &all_join_keys) {
                 Some(filter_expr) => Filter::try_new(filter_expr, 
Arc::new(left))
                     .map(|f| Some(LogicalPlan::Filter(f))),
                 _ => Ok(Some(left)),
@@ -150,7 +150,7 @@ impl OptimizerRule for EliminateCrossJoin {
 /// Returns a boolean indicating whether the flattening was successful.
 fn try_flatten_join_inputs(
     plan: &LogicalPlan,
-    possible_join_keys: &mut Vec<(Expr, Expr)>,
+    possible_join_keys: &mut JoinKeySet,
     all_inputs: &mut Vec<LogicalPlan>,
 ) -> Result<bool> {
     let children = match plan {
@@ -160,7 +160,7 @@ fn try_flatten_join_inputs(
                 // issue: https://github.com/apache/datafusion/issues/4844
                 return Ok(false);
             }
-            possible_join_keys.extend(join.on.clone());
+            possible_join_keys.insert_all(join.on.iter());
             vec![&join.left, &join.right]
         }
         LogicalPlan::CrossJoin(join) => {
@@ -204,8 +204,8 @@ fn try_flatten_join_inputs(
 fn find_inner_join(
     left_input: &LogicalPlan,
     rights: &mut Vec<LogicalPlan>,
-    possible_join_keys: &[(Expr, Expr)],
-    all_join_keys: &mut HashSet<(Expr, Expr)>,
+    possible_join_keys: &JoinKeySet,
+    all_join_keys: &mut JoinKeySet,
 ) -> Result<LogicalPlan> {
     for (i, right_input) in rights.iter().enumerate() {
         let mut join_keys = vec![];
@@ -228,7 +228,7 @@ fn find_inner_join(
 
         // Found one or more matching join keys
         if !join_keys.is_empty() {
-            all_join_keys.extend(join_keys.clone());
+            all_join_keys.insert_all(join_keys.iter());
             let right_input = rights.remove(i);
             let join_schema = Arc::new(build_join_schema(
                 left_input.schema(),
@@ -265,90 +265,67 @@ fn find_inner_join(
     }))
 }
 
-fn intersect(
-    accum: &mut Vec<(Expr, Expr)>,
-    vec1: &[(Expr, Expr)],
-    vec2: &[(Expr, Expr)],
-) {
-    if !(vec1.is_empty() || vec2.is_empty()) {
-        for x1 in vec1.iter() {
-            for x2 in vec2.iter() {
-                if x1.0 == x2.0 && x1.1 == x2.1 || x1.1 == x2.0 && x1.0 == 
x2.1 {
-                    accum.push((x1.0.clone(), x1.1.clone()));
-                }
-            }
-        }
-    }
-}
-
 /// Extract join keys from a WHERE clause
-fn extract_possible_join_keys(expr: &Expr, accum: &mut Vec<(Expr, Expr)>) -> 
Result<()> {
+fn extract_possible_join_keys(expr: &Expr, join_keys: &mut JoinKeySet) {
     if let Expr::BinaryExpr(BinaryExpr { left, op, right }) = expr {
         match op {
             Operator::Eq => {
-                // Ensure that we don't add the same Join keys multiple times
-                if !(accum.contains(&(*left.clone(), *right.clone()))
-                    || accum.contains(&(*right.clone(), *left.clone())))
-                {
-                    accum.push((*left.clone(), *right.clone()));
-                }
+                // `insert` ensures we don't add the same join keys multiple times
+                join_keys.insert(left, right);
             }
             Operator::And => {
-                extract_possible_join_keys(left, accum)?;
-                extract_possible_join_keys(right, accum)?
+                extract_possible_join_keys(left, join_keys);
+                extract_possible_join_keys(right, join_keys)
             }
             // Fix for issue#78 join predicates from inside of OR expr also pulled up properly.
             Operator::Or => {
-                let mut left_join_keys = vec![];
-                let mut right_join_keys = vec![];
+                let mut left_join_keys = JoinKeySet::new();
+                let mut right_join_keys = JoinKeySet::new();
 
-                extract_possible_join_keys(left, &mut left_join_keys)?;
-                extract_possible_join_keys(right, &mut right_join_keys)?;
+                extract_possible_join_keys(left, &mut left_join_keys);
+                extract_possible_join_keys(right, &mut right_join_keys);
 
-                intersect(accum, &left_join_keys, &right_join_keys)
+                join_keys.insert_intersection(left_join_keys, right_join_keys)
             }
             _ => (),
         };
     }
-    Ok(())
 }
 
 /// Remove join expressions from a filter expression
-/// Returns Some() when there are few remaining predicates in filter_expr
-/// Returns None otherwise
-fn remove_join_expressions(
-    expr: &Expr,
-    join_keys: &HashSet<(Expr, Expr)>,
-) -> Result<Option<Expr>> {
+///
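+/// Returns `Some(expr)` with the predicates that remain after removing join keys, or
+/// `None` if none remain. NOTE(review): when one arm of an `OR` is a removed join key, the
+/// other arm is kept, yet the join makes that arm true so the `OR` is true — please verify.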
+fn remove_join_expressions(expr: Expr, join_keys: &JoinKeySet) -> Option<Expr> 
{
     match expr {
-        Expr::BinaryExpr(BinaryExpr { left, op, right }) => {
-            match op {
-                Operator::Eq => {
-                    if join_keys.contains(&(*left.clone(), *right.clone()))
-                        || join_keys.contains(&(*right.clone(), *left.clone()))
-                    {
-                        Ok(None)
-                    } else {
-                        Ok(Some(expr.clone()))
-                    }
-                }
-                // Fix for issue#78 join predicates from inside of OR expr 
also pulled up properly.
-                Operator::And | Operator::Or => {
-                    let l = remove_join_expressions(left, join_keys)?;
-                    let r = remove_join_expressions(right, join_keys)?;
-                    match (l, r) {
-                        (Some(ll), Some(rr)) => Ok(Some(Expr::BinaryExpr(
-                            BinaryExpr::new(Box::new(ll), *op, Box::new(rr)),
-                        ))),
-                        (Some(ll), _) => Ok(Some(ll)),
-                        (_, Some(rr)) => Ok(Some(rr)),
-                        _ => Ok(None),
-                    }
-                }
-                _ => Ok(Some(expr.clone())),
+        Expr::BinaryExpr(BinaryExpr {
+            left,
+            op: Operator::Eq,
+            right,
+        }) if join_keys.contains(&left, &right) => {
+            // was a join key, so remove it
+            None
+        }
+        // Fix for issue#78: join predicates from inside OR exprs are also pulled up properly.
+        Expr::BinaryExpr(BinaryExpr { left, op, right })
+            if matches!(op, Operator::And | Operator::Or) =>
+        {
+            let l = remove_join_expressions(*left, join_keys);
+            let r = remove_join_expressions(*right, join_keys);
+            match (l, r) {
+                (Some(ll), Some(rr)) => Some(Expr::BinaryExpr(BinaryExpr::new(
+                    Box::new(ll),
+                    op,
+                    Box::new(rr),
+                ))),
+                (Some(ll), _) => Some(ll),
+                (_, Some(rr)) => Some(rr),
+                _ => None,
             }
         }
-        _ => Ok(Some(expr.clone())),
+
+        _ => Some(expr),
     }
 }
 
diff --git a/datafusion/optimizer/src/join_key_set.rs b/datafusion/optimizer/src/join_key_set.rs
new file mode 100644
index 0000000000..c47afa012c
--- /dev/null
+++ b/datafusion/optimizer/src/join_key_set.rs
@@ -0,0 +1,240 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements.  See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership.  The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License.  You may obtain a copy of the License at
+//
+//   http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied.  See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+//! [JoinKeySet] for tracking the set of join keys in a plan.
+
+use datafusion_expr::Expr;
+use indexmap::{Equivalent, IndexSet};
+
+/// Tracks a set of equality join keys
+///
+/// A join key is a pair of expressions used to join two tables via an
+/// equality predicate such as `a.x = b.y`.
+///
+/// This struct models `a.x + 5 = b.y AND a.z = b.z` as two join keys:
+/// 1. `(a.x + 5,  b.y)`
+/// 2. `(a.z,      b.z)`
+///
+/// # Important properties
+///
+/// 1. Retains insertion order (the pairs are stored in an [`IndexSet`]).
+/// 2. Can quickly look up a pair in either orientation: `(l, r)` matches `(r, l)`.
+#[derive(Debug)]
+pub struct JoinKeySet {
+    inner: IndexSet<(Expr, Expr)>,
+}
+
+impl JoinKeySet {
+    /// Creates a new, empty set
+    pub fn new() -> Self {
+        Self {
+            inner: IndexSet::new(),
+        }
+    }
+
+    /// Returns `true` if the set contains the join pair `(left, right)`
+    /// or its mirror `(right, left)`
+    pub fn contains(&self, left: &Expr, right: &Expr) -> bool {
+        self.inner.contains(&ExprPair::new(left, right))
+            || self.inner.contains(&ExprPair::new(right, left))
+    }
+
+    /// Inserts the join key `(left = right)` unless that pair, or its mirror
+    /// `(right = left)`, is already in the set
+    ///
+    /// Returns `true` if the pair was inserted (both expressions are cloned)
+    pub fn insert(&mut self, left: &Expr, right: &Expr) -> bool {
+        if self.contains(left, right) {
+            false
+        } else {
+            self.inner.insert((left.clone(), right.clone()));
+            true
+        }
+    }
+
+    /// Inserts potentially many join keys, cloning only the pairs that are new
+    ///
+    /// Returns `true` if any of the pairs were inserted
+    pub fn insert_all<'a>(
+        &mut self,
+        iter: impl Iterator<Item = &'a (Expr, Expr)>,
+    ) -> bool {
+        let mut inserted = false;
+        for (left, right) in iter {
+            inserted |= self.insert(left, right);
+        }
+        inserted
+    }
+
+    /// Inserts into `self` each key of `s1` that `s2` also contains (in either order)
+    pub fn insert_intersection(&mut self, s1: JoinKeySet, s2: JoinKeySet) {
+        // Note: `IndexSet::intersection` can't be used because `(l, r)` and `(r, l)`
+        // must compare equal; keys are added in `s1`'s insertion order
+        for (left, right) in s1.inner.iter() {
+            if s2.contains(left, right) {
+                self.insert(left, right);
+            }
+        }
+    }
+
+    /// Returns `true` if this set contains no join keys
+    pub fn is_empty(&self) -> bool {
+        self.inner.is_empty()
+    }
+
+    /// Returns the number of join keys in this set (only compiled for tests)
+    #[cfg(test)]
+    pub fn len(&self) -> usize {
+        self.inner.len()
+    }
+
+    /// Returns an iterator over the join keys, in insertion order
+    pub fn iter(&self) -> impl Iterator<Item = (&Expr, &Expr)> {
+        self.inner.iter().map(|(l, r)| (l, r))
+    }
+}
+
+/// Borrowed lookup key that lets a [`JoinKeySet`] be probed without cloning
+///
+/// Behaves like a `(Expr, Expr)` tuple for hashing and comparison: the derived
+/// `Hash` visits both fields in tuple order, as [`Equivalent`] lookups require.
+/// This avoids copying the expressions simply to compare them.
+#[derive(Debug, Eq, PartialEq, Hash)]
+struct ExprPair<'a>(&'a Expr, &'a Expr);
+
+impl<'a> ExprPair<'a> {
+    fn new(left: &'a Expr, right: &'a Expr) -> Self {
+        Self(left, right)
+    }
+}
+
+impl<'a> Equivalent<(Expr, Expr)> for ExprPair<'a> {
+    fn equivalent(&self, other: &(Expr, Expr)) -> bool {
+        self.0 == &other.0 && self.1 == &other.1
+    }
+}
+
+#[cfg(test)]
+mod test {
+    use crate::join_key_set::JoinKeySet;
+    use datafusion_expr::{col, Expr};
+
+    #[test]
+    fn test_insert() {
+        let mut set = JoinKeySet::new();
+        // new sets should be empty
+        assert!(set.is_empty());
+
+        // insert (a = b)
+        assert!(set.insert(&col("a"), &col("b")));
+        assert!(!set.is_empty());
+
+        // insert (a=b) again returns false
+        assert!(!set.insert(&col("a"), &col("b")));
+        assert_eq!(set.len(), 1);
+
+        // insert (b = a): the mirror of (a = b), so it is not added again
+        assert!(!set.insert(&col("b"), &col("a")));
+        assert_eq!(set.len(), 1);
+
+        // insert (a = c) should be considered different
+        assert!(set.insert(&col("a"), &col("c")));
+        assert_eq!(set.len(), 2);
+    }
+
+    #[test]
+    fn test_contains() {
+        let mut set = JoinKeySet::new();
+        assert!(set.insert(&col("a"), &col("b")));
+        assert!(set.contains(&col("a"), &col("b")));
+        assert!(set.contains(&col("b"), &col("a")));
+        assert!(!set.contains(&col("a"), &col("c")));
+
+        assert!(set.insert(&col("a"), &col("c")));
+        assert!(set.contains(&col("a"), &col("c")));
+        assert!(set.contains(&col("c"), &col("a")));
+    }
+
+    #[test]
+    fn test_iterator() {
+        // iteration keeps insertion order and skips mirrored duplicates
+        let mut set = JoinKeySet::new();
+        // put in c = a, b = c, and a = c and expect to get only the first 2
+        set.insert(&col("c"), &col("a"));
+        set.insert(&col("b"), &col("c"));
+        set.insert(&col("a"), &col("c"));
+        assert_contents(&set, vec![(&col("c"), &col("a")), (&col("b"), 
&col("c"))]);
+    }
+
+    #[test]
+    fn test_insert_intersection() {
+        // a = b, b = c, c = d
+        let mut set1 = JoinKeySet::new();
+        set1.insert(&col("a"), &col("b"));
+        set1.insert(&col("b"), &col("c"));
+        set1.insert(&col("c"), &col("d"));
+
+        // a = a, b = b, b = c, d = c
+        // should only intersect on b = c and c = d
+        let mut set2 = JoinKeySet::new();
+        set2.insert(&col("a"), &col("a"));
+        set2.insert(&col("b"), &col("b"));
+        set2.insert(&col("b"), &col("c"));
+        set2.insert(&col("d"), &col("c"));
+
+        let mut set = JoinKeySet::new();
+        // put something in there already
+        set.insert(&col("x"), &col("y"));
+        set.insert_intersection(set1, set2);
+
+        assert_contents(
+            &set,
+            vec![
+                (&col("x"), &col("y")),
+                (&col("b"), &col("c")),
+                (&col("c"), &col("d")),
+            ],
+        );
+    }
+
+    fn assert_contents(set: &JoinKeySet, expected: Vec<(&Expr, &Expr)>) {
+        let contents: Vec<_> = set.iter().collect();
+        assert_eq!(contents, expected);
+    }
+
+    #[test]
+    fn test_insert_many() {
+        let mut set = JoinKeySet::new();
+
+        // insert (a=b), (b=c), (b=a)
+        set.insert_all(
+            vec![
+                &(col("a"), col("b")),
+                &(col("b"), col("c")),
+                &(col("b"), col("a")),
+            ]
+            .into_iter(),
+        );
+        assert_eq!(set.len(), 2);
+        assert!(set.contains(&col("a"), &col("b")));
+        assert!(set.contains(&col("b"), &col("c")));
+        assert!(set.contains(&col("b"), &col("a")));
+
+        // should not contain (a=c)
+        assert!(!set.contains(&col("a"), &col("c")));
+    }
+}
diff --git a/datafusion/optimizer/src/lib.rs b/datafusion/optimizer/src/lib.rs
index 9176d67c1d..793c87f8bc 100644
--- a/datafusion/optimizer/src/lib.rs
+++ b/datafusion/optimizer/src/lib.rs
@@ -62,6 +62,7 @@ pub use analyzer::{Analyzer, AnalyzerRule};
 pub use optimizer::{Optimizer, OptimizerConfig, OptimizerContext, 
OptimizerRule};
 pub use utils::optimize_children;
 
+pub(crate) mod join_key_set;
 mod plan_signature;
 
 #[cfg(test)]


---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]

Reply via email to