This is an automated email from the ASF dual-hosted git repository.

liukun pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/arrow-datafusion.git


The following commit(s) were added to refs/heads/master by this push:
     new 1b03a7a35 Dataframe join_on method (#5210)
1b03a7a35 is described below

commit 1b03a7a35aad77456cb3fca58e37612903c96aec
Author: Jeffrey <[email protected]>
AuthorDate: Thu Feb 9 21:59:59 2023 +1100

    Dataframe join_on method (#5210)
    
    * Dataframe join_on method
    
    * Fix formatting
    
    * Add tests
---
 datafusion/common/src/table_reference.rs    |  1 +
 datafusion/core/src/dataframe.rs            | 92 +++++++++++++++++++++++++++++
 datafusion/expr/src/logical_plan/builder.rs | 24 +++++++-
 datafusion/expr/src/utils.rs                | 61 ++++++++++++++++++-
 datafusion/sql/src/relation/join.rs         | 84 ++------------------------
 docs/source/user-guide/dataframe.md         |  1 +
 6 files changed, 181 insertions(+), 82 deletions(-)

diff --git a/datafusion/common/src/table_reference.rs b/datafusion/common/src/table_reference.rs
index 370f5e46e..1e6292b29 100644
--- a/datafusion/common/src/table_reference.rs
+++ b/datafusion/common/src/table_reference.rs
@@ -194,6 +194,7 @@ impl<'a> TableReference<'a> {
     /// failing that then taking the entire unnormalized input as the identifier itself.
     ///
     /// Will normalize (convert to lowercase) any unquoted identifiers.
+    ///
     /// e.g. `Foo` will be parsed as `foo`, and `"Foo"".bar"` will be parsed as
     /// `Foo".bar` (note the preserved case and requiring two double quotes to represent
     /// a single double quote in the identifier)
diff --git a/datafusion/core/src/dataframe.rs b/datafusion/core/src/dataframe.rs
index 557e04a3b..26fe5c051 100644
--- a/datafusion/core/src/dataframe.rs
+++ b/datafusion/core/src/dataframe.rs
@@ -363,6 +363,55 @@ impl DataFrame {
         Ok(DataFrame::new(self.session_state, plan))
     }
 
+    /// Join this DataFrame with another DataFrame using the specified expressions.
+    ///
+    /// Simply a thin wrapper over [`join`](Self::join) where the join keys are not provided,
+    /// and the provided expressions are AND'ed together to form the filter expression.
+    ///
+    /// ```
+    /// # use datafusion::prelude::*;
+    /// # use datafusion::error::Result;
+    /// # #[tokio::main]
+    /// # async fn main() -> Result<()> {
+    /// let ctx = SessionContext::new();
+    /// let left = ctx
+    ///     .read_csv("tests/data/example.csv", CsvReadOptions::new())
+    ///     .await?;
+    /// let right = ctx
+    ///     .read_csv("tests/data/example.csv", CsvReadOptions::new())
+    ///     .await?
+    ///     .select(vec![
+    ///         col("a").alias("a2"),
+    ///         col("b").alias("b2"),
+    ///         col("c").alias("c2"),
+    ///     ])?;
+    /// let join_on = left.join_on(
+    ///     right,
+    ///     JoinType::Inner,
+    ///     [col("a").not_eq(col("a2")), col("b").not_eq(col("b2"))],
+    /// )?;
+    /// let batches = join_on.collect().await?;
+    /// # Ok(())
+    /// # }
+    /// ```
+    pub fn join_on(
+        self,
+        right: DataFrame,
+        join_type: JoinType,
+        on_exprs: impl IntoIterator<Item = Expr>,
+    ) -> Result<DataFrame> {
+        let expr = on_exprs.into_iter().reduce(Expr::and);
+        let plan = LogicalPlanBuilder::from(self.plan)
+            .join(
+                right.plan,
+                join_type,
+                (Vec::<Column>::new(), Vec::<Column>::new()),
+                expr,
+            )?
+            .build()?;
+        Ok(DataFrame::new(self.session_state, plan))
+    }
+
     /// Repartition a DataFrame based on a logical partitioning scheme.
     ///
     /// ```
@@ -1039,6 +1088,49 @@ mod tests {
         Ok(())
     }
 
+    #[tokio::test]
+    async fn join_on() -> Result<()> {
+        let left = test_table_with_name("a")
+            .await?
+            .select_columns(&["c1", "c2"])?;
+        let right = test_table_with_name("b")
+            .await?
+            .select_columns(&["c1", "c2"])?;
+        let join = left.join_on(
+            right,
+            JoinType::Inner,
+            [col("a.c1").not_eq(col("b.c1")), col("a.c2").eq(col("b.c2"))],
+        )?;
+
+        let expected_plan = "Inner Join:  Filter: a.c1 != b.c1 AND a.c2 = b.c2\
+        \n  Projection: a.c1, a.c2\
+        \n    TableScan: a\
+        \n  Projection: b.c1, b.c2\
+        \n    TableScan: b";
+        assert_eq!(expected_plan, format!("{:?}", join.logical_plan()));
+
+        Ok(())
+    }
+
+    #[tokio::test]
+    async fn join_ambiguous_filter() -> Result<()> {
+        let left = test_table_with_name("a")
+            .await?
+            .select_columns(&["c1", "c2"])?;
+        let right = test_table_with_name("b")
+            .await?
+            .select_columns(&["c1", "c2"])?;
+
+        let join = left
+            .join_on(right, JoinType::Inner, [col("c1").eq(col("c1"))])
+            .expect_err("join didn't fail check");
+        let expected =
+            "Error during planning: reference 'c1' is ambiguous, could be a.c1,b.c1;";
+        assert_eq!(join.to_string(), expected);
+
+        Ok(())
+    }
+
     #[tokio::test]
     async fn limit() -> Result<()> {
         // build query using Table API
diff --git a/datafusion/expr/src/logical_plan/builder.rs b/datafusion/expr/src/logical_plan/builder.rs
index 23256662f..4bbb83bb7 100644
--- a/datafusion/expr/src/logical_plan/builder.rs
+++ b/datafusion/expr/src/logical_plan/builder.rs
@@ -22,7 +22,10 @@ use crate::expr_rewriter::{
     normalize_cols, rewrite_sort_cols_by_aggs,
 };
 use crate::type_coercion::binary::comparison_coercion;
-use crate::utils::{columnize_expr, compare_sort_expr, exprlist_to_fields, from_plan};
+use crate::utils::{
+    columnize_expr, compare_sort_expr, ensure_any_column_reference_is_unambiguous,
+    exprlist_to_fields, from_plan,
+};
 use crate::{and, binary_expr, Operator};
 use crate::{
     logical_plan::{
@@ -502,6 +505,25 @@ impl LogicalPlanBuilder {
             ));
         }
 
+        let filter = if let Some(expr) = filter {
+            // ambiguous check
+            ensure_any_column_reference_is_unambiguous(
+                &expr,
+                &[self.schema(), right.schema()],
+            )?;
+
+            // normalize all columns in expression
+            let using_columns = expr.to_columns()?;
+            let filter = normalize_col_with_schemas(
+                expr,
+                &[self.schema(), right.schema()],
+                &[using_columns],
+            )?;
+            Some(filter)
+        } else {
+            None
+        };
+
         let (left_keys, right_keys): (Vec<Result<Column>>, Vec<Result<Column>>) =
             join_keys
                 .0
diff --git a/datafusion/expr/src/utils.rs b/datafusion/expr/src/utils.rs
index 6f64bc14f..8ce959e79 100644
--- a/datafusion/expr/src/utils.rs
+++ b/datafusion/expr/src/utils.rs
@@ -37,7 +37,7 @@ use datafusion_common::{
     Column, DFField, DFSchema, DFSchemaRef, DataFusionError, Result, ScalarValue,
 };
 use std::cmp::Ordering;
-use std::collections::HashSet;
+use std::collections::{HashMap, HashSet};
 use std::sync::Arc;
 
 ///  The value to which `COUNT(*)` is expanded to in
@@ -1023,6 +1023,65 @@ pub fn find_valid_equijoin_key_pair(
     Ok(join_key_pair)
 }
 
+/// Ensure any column reference of the expression is unambiguous.
+/// Assume we have two schema:
+/// schema1: a, b ,c
+/// schema2: a, d, e
+///
+/// `schema1.a + schema2.a` is unambiguous.
+/// `a + d` is ambiguous, because `a` may come from schema1 or schema2.
+pub fn ensure_any_column_reference_is_unambiguous(
+    expr: &Expr,
+    schemas: &[&DFSchema],
+) -> Result<()> {
+    if schemas.len() == 1 {
+        return Ok(());
+    }
+    // all referenced columns in the expression that don't have relation
+    let referenced_cols = expr.to_columns()?;
+    let mut no_relation_cols = referenced_cols
+        .iter()
+        .filter_map(|col| {
+            if col.relation.is_none() {
+                Some((col.name.as_str(), 0))
+            } else {
+                None
+            }
+        })
+        .collect::<HashMap<&str, u8>>();
+    // find the name of the column existing in multi schemas.
+    let ambiguous_col_name = schemas
+        .iter()
+        .flat_map(|schema| schema.fields())
+        .map(|field| field.name())
+        .find(|col_name| {
+            no_relation_cols.entry(col_name).and_modify(|v| *v += 1);
+            matches!(
+                no_relation_cols.get_key_value(col_name.as_str()),
+                Some((_, 2..))
+            )
+        });
+
+    if let Some(col_name) = ambiguous_col_name {
+        let maybe_field = schemas
+            .iter()
+            .flat_map(|schema| {
+                schema
+                    .field_with_unqualified_name(col_name)
+                    .map(|f| f.qualified_name())
+                    .ok()
+            })
+            .collect::<Vec<_>>();
+        Err(DataFusionError::Plan(format!(
+            "reference \'{}\' is ambiguous, could be {};",
+            col_name,
+            maybe_field.join(","),
+        )))
+    } else {
+        Ok(())
+    }
+}
+
 #[cfg(test)]
 mod tests {
     use super::*;
diff --git a/datafusion/sql/src/relation/join.rs b/datafusion/sql/src/relation/join.rs
index 6f2233f39..591194136 100644
--- a/datafusion/sql/src/relation/join.rs
+++ b/datafusion/sql/src/relation/join.rs
@@ -17,11 +17,10 @@
 
 use crate::planner::{ContextProvider, PlannerContext, SqlToRel};
 use crate::utils::normalize_ident;
-use datafusion_common::{Column, DFSchemaRef, DataFusionError, Result};
-use datafusion_expr::expr_rewriter::normalize_col_with_schemas;
-use datafusion_expr::{Expr, JoinType, LogicalPlan, LogicalPlanBuilder};
+use datafusion_common::{Column, DataFusionError, Result};
+use datafusion_expr::{JoinType, LogicalPlan, LogicalPlanBuilder};
 use sqlparser::ast::{Join, JoinConstraint, JoinOperator, TableWithJoins};
-use std::collections::{HashMap, HashSet};
+use std::collections::HashSet;
 
 impl<'a, S: ContextProvider> SqlToRel<'a, S> {
     pub(crate) fn plan_table_with_joins(
@@ -133,30 +132,14 @@ impl<'a, S: ContextProvider> SqlToRel<'a, S> {
         match constraint {
             JoinConstraint::On(sql_expr) => {
                 let join_schema = left.schema().join(right.schema())?;
-
                 // parse ON expression
                 let expr = self.sql_to_expr(sql_expr, &join_schema, planner_context)?;
-
-                // ambiguous check
-                ensure_any_column_reference_is_unambiguous(
-                    &expr,
-                    &[left.schema().clone(), right.schema().clone()],
-                )?;
-
-                // normalize all columns in expression
-                let using_columns = expr.to_columns()?;
-                let filter = normalize_col_with_schemas(
-                    expr,
-                    &[left.schema(), right.schema()],
-                    &[using_columns],
-                )?;
-
                 LogicalPlanBuilder::from(left)
                     .join(
                         right,
                         join_type,
                         (Vec::<Column>::new(), Vec::<Column>::new()),
-                        Some(filter),
+                        Some(expr),
                     )?
                     .build()
             }
@@ -198,62 +181,3 @@ impl<'a, S: ContextProvider> SqlToRel<'a, S> {
         }
     }
 }
-
-/// Ensure any column reference of the expression is unambiguous.
-/// Assume we have two schema:
-/// schema1: a, b ,c
-/// schema2: a, d, e
-///
-/// `schema1.a + schema2.a` is unambiguous.
-/// `a + d` is ambiguous, because `a` may come from schema1 or schema2.
-fn ensure_any_column_reference_is_unambiguous(
-    expr: &Expr,
-    schemas: &[DFSchemaRef],
-) -> Result<()> {
-    if schemas.len() == 1 {
-        return Ok(());
-    }
-    // all referenced columns in the expression that don't have relation
-    let referenced_cols = expr.to_columns()?;
-    let mut no_relation_cols = referenced_cols
-        .iter()
-        .filter_map(|col| {
-            if col.relation.is_none() {
-                Some((col.name.as_str(), 0))
-            } else {
-                None
-            }
-        })
-        .collect::<HashMap<&str, u8>>();
-    // find the name of the column existing in multi schemas.
-    let ambiguous_col_name = schemas
-        .iter()
-        .flat_map(|schema| schema.fields())
-        .map(|field| field.name())
-        .find(|col_name| {
-            no_relation_cols.entry(col_name).and_modify(|v| *v += 1);
-            matches!(
-                no_relation_cols.get_key_value(col_name.as_str()),
-                Some((_, 2..))
-            )
-        });
-
-    if let Some(col_name) = ambiguous_col_name {
-        let maybe_field = schemas
-            .iter()
-            .flat_map(|schema| {
-                schema
-                    .field_with_unqualified_name(col_name)
-                    .map(|f| f.qualified_name())
-                    .ok()
-            })
-            .collect::<Vec<_>>();
-        Err(DataFusionError::Plan(format!(
-            "reference \'{}\' is ambiguous, could be {};",
-            col_name,
-            maybe_field.join(","),
-        )))
-    } else {
-        Ok(())
-    }
-}
diff --git a/docs/source/user-guide/dataframe.md b/docs/source/user-guide/dataframe.md
index 5ba803fce..c7d490e40 100644
--- a/docs/source/user-guide/dataframe.md
+++ b/docs/source/user-guide/dataframe.md
@@ -68,6 +68,7 @@ execution. The plan is evaluated (executed) when an action method is invoked, su
 | filter              | Filter a DataFrame to only include rows that match the specified filter expression.                                                        |
 | intersect           | Calculate the intersection of two DataFrames. The two DataFrames must have exactly the same schema                                         |
 | join                | Join this DataFrame with another DataFrame using the specified columns as join keys.                                                       |
+| join_on             | Join this DataFrame with another DataFrame using arbitrary expressions.                                                                    |
 | limit               | Limit the number of rows returned from this DataFrame.                                                                                     |
 | repartition         | Repartition a DataFrame based on a logical partitioning scheme.                                                                            |
 | sort                | Sort the DataFrame by the specified sorting expressions. Any expression can be turned into a sort expression by calling its `sort` method. |

Reply via email to