This is an automated email from the ASF dual-hosted git repository.
liukun pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/arrow-datafusion.git
The following commit(s) were added to refs/heads/master by this push:
new 1b03a7a35 Dataframe join_on method (#5210)
1b03a7a35 is described below
commit 1b03a7a35aad77456cb3fca58e37612903c96aec
Author: Jeffrey <[email protected]>
AuthorDate: Thu Feb 9 21:59:59 2023 +1100
Dataframe join_on method (#5210)
* Dataframe join_on method
* Fix formatting
* Add tests
---
datafusion/common/src/table_reference.rs | 1 +
datafusion/core/src/dataframe.rs | 92 +++++++++++++++++++++++++++++
datafusion/expr/src/logical_plan/builder.rs | 24 +++++++-
datafusion/expr/src/utils.rs | 61 ++++++++++++++++++-
datafusion/sql/src/relation/join.rs | 84 ++------------------------
docs/source/user-guide/dataframe.md | 1 +
6 files changed, 181 insertions(+), 82 deletions(-)
diff --git a/datafusion/common/src/table_reference.rs
b/datafusion/common/src/table_reference.rs
index 370f5e46e..1e6292b29 100644
--- a/datafusion/common/src/table_reference.rs
+++ b/datafusion/common/src/table_reference.rs
@@ -194,6 +194,7 @@ impl<'a> TableReference<'a> {
/// failing that then taking the entire unnormalized input as the
identifier itself.
///
/// Will normalize (convert to lowercase) any unquoted identifiers.
+ ///
/// e.g. `Foo` will be parsed as `foo`, and `"Foo"".bar"` will be parsed as
/// `Foo".bar` (note the preserved case and requiring two double quotes to
represent
/// a single double quote in the identifier)
diff --git a/datafusion/core/src/dataframe.rs b/datafusion/core/src/dataframe.rs
index 557e04a3b..26fe5c051 100644
--- a/datafusion/core/src/dataframe.rs
+++ b/datafusion/core/src/dataframe.rs
@@ -363,6 +363,55 @@ impl DataFrame {
Ok(DataFrame::new(self.session_state, plan))
}
+ /// Join this DataFrame with another DataFrame using the specified
expressions.
+ ///
+ /// Simply a thin wrapper over [`join`](Self::join) where the join keys
are not provided,
+ /// and the provided expressions are AND'ed together to form the filter
expression.
+ ///
+ /// ```
+ /// # use datafusion::prelude::*;
+ /// # use datafusion::error::Result;
+ /// # #[tokio::main]
+ /// # async fn main() -> Result<()> {
+ /// let ctx = SessionContext::new();
+ /// let left = ctx
+ /// .read_csv("tests/data/example.csv", CsvReadOptions::new())
+ /// .await?;
+ /// let right = ctx
+ /// .read_csv("tests/data/example.csv", CsvReadOptions::new())
+ /// .await?
+ /// .select(vec![
+ /// col("a").alias("a2"),
+ /// col("b").alias("b2"),
+ /// col("c").alias("c2"),
+ /// ])?;
+ /// let join_on = left.join_on(
+ /// right,
+ /// JoinType::Inner,
+ /// [col("a").not_eq(col("a2")), col("b").not_eq(col("b2"))],
+ /// )?;
+ /// let batches = join_on.collect().await?;
+ /// # Ok(())
+ /// # }
+ /// ```
+ pub fn join_on(
+ self,
+ right: DataFrame,
+ join_type: JoinType,
+ on_exprs: impl IntoIterator<Item = Expr>,
+ ) -> Result<DataFrame> {
+ let expr = on_exprs.into_iter().reduce(Expr::and);
+ let plan = LogicalPlanBuilder::from(self.plan)
+ .join(
+ right.plan,
+ join_type,
+ (Vec::<Column>::new(), Vec::<Column>::new()),
+ expr,
+ )?
+ .build()?;
+ Ok(DataFrame::new(self.session_state, plan))
+ }
+
/// Repartition a DataFrame based on a logical partitioning scheme.
///
/// ```
@@ -1039,6 +1088,49 @@ mod tests {
Ok(())
}
+ #[tokio::test]
+ async fn join_on() -> Result<()> {
+ let left = test_table_with_name("a")
+ .await?
+ .select_columns(&["c1", "c2"])?;
+ let right = test_table_with_name("b")
+ .await?
+ .select_columns(&["c1", "c2"])?;
+ let join = left.join_on(
+ right,
+ JoinType::Inner,
+ [col("a.c1").not_eq(col("b.c1")), col("a.c2").eq(col("b.c2"))],
+ )?;
+
+ let expected_plan = "Inner Join: Filter: a.c1 != b.c1 AND a.c2 = b.c2\
+ \n Projection: a.c1, a.c2\
+ \n TableScan: a\
+ \n Projection: b.c1, b.c2\
+ \n TableScan: b";
+ assert_eq!(expected_plan, format!("{:?}", join.logical_plan()));
+
+ Ok(())
+ }
+
+ #[tokio::test]
+ async fn join_ambiguous_filter() -> Result<()> {
+ let left = test_table_with_name("a")
+ .await?
+ .select_columns(&["c1", "c2"])?;
+ let right = test_table_with_name("b")
+ .await?
+ .select_columns(&["c1", "c2"])?;
+
+ let join = left
+ .join_on(right, JoinType::Inner, [col("c1").eq(col("c1"))])
+ .expect_err("join didn't fail check");
+ let expected =
+ "Error during planning: reference 'c1' is ambiguous, could be
a.c1,b.c1;";
+ assert_eq!(join.to_string(), expected);
+
+ Ok(())
+ }
+
#[tokio::test]
async fn limit() -> Result<()> {
// build query using Table API
diff --git a/datafusion/expr/src/logical_plan/builder.rs
b/datafusion/expr/src/logical_plan/builder.rs
index 23256662f..4bbb83bb7 100644
--- a/datafusion/expr/src/logical_plan/builder.rs
+++ b/datafusion/expr/src/logical_plan/builder.rs
@@ -22,7 +22,10 @@ use crate::expr_rewriter::{
normalize_cols, rewrite_sort_cols_by_aggs,
};
use crate::type_coercion::binary::comparison_coercion;
-use crate::utils::{columnize_expr, compare_sort_expr, exprlist_to_fields,
from_plan};
+use crate::utils::{
+ columnize_expr, compare_sort_expr,
ensure_any_column_reference_is_unambiguous,
+ exprlist_to_fields, from_plan,
+};
use crate::{and, binary_expr, Operator};
use crate::{
logical_plan::{
@@ -502,6 +505,25 @@ impl LogicalPlanBuilder {
));
}
+ let filter = if let Some(expr) = filter {
+ // ambiguous check
+ ensure_any_column_reference_is_unambiguous(
+ &expr,
+ &[self.schema(), right.schema()],
+ )?;
+
+ // normalize all columns in expression
+ let using_columns = expr.to_columns()?;
+ let filter = normalize_col_with_schemas(
+ expr,
+ &[self.schema(), right.schema()],
+ &[using_columns],
+ )?;
+ Some(filter)
+ } else {
+ None
+ };
+
let (left_keys, right_keys): (Vec<Result<Column>>,
Vec<Result<Column>>) =
join_keys
.0
diff --git a/datafusion/expr/src/utils.rs b/datafusion/expr/src/utils.rs
index 6f64bc14f..8ce959e79 100644
--- a/datafusion/expr/src/utils.rs
+++ b/datafusion/expr/src/utils.rs
@@ -37,7 +37,7 @@ use datafusion_common::{
Column, DFField, DFSchema, DFSchemaRef, DataFusionError, Result,
ScalarValue,
};
use std::cmp::Ordering;
-use std::collections::HashSet;
+use std::collections::{HashMap, HashSet};
use std::sync::Arc;
/// The value to which `COUNT(*)` is expanded to in
@@ -1023,6 +1023,65 @@ pub fn find_valid_equijoin_key_pair(
Ok(join_key_pair)
}
+/// Ensure any column reference of the expression is unambiguous.
+/// Assume we have two schema:
+/// schema1: a, b, c
+/// schema2: a, d, e
+///
+/// `schema1.a + schema2.a` is unambiguous.
+/// `a + d` is ambiguous, because `a` may come from schema1 or schema2.
+pub fn ensure_any_column_reference_is_unambiguous(
+ expr: &Expr,
+ schemas: &[&DFSchema],
+) -> Result<()> {
+ if schemas.len() == 1 {
+ return Ok(());
+ }
+ // all referenced columns in the expression that don't have relation
+ let referenced_cols = expr.to_columns()?;
+ let mut no_relation_cols = referenced_cols
+ .iter()
+ .filter_map(|col| {
+ if col.relation.is_none() {
+ Some((col.name.as_str(), 0))
+ } else {
+ None
+ }
+ })
+ .collect::<HashMap<&str, u8>>();
+ // find the name of the column existing in multi schemas.
+ let ambiguous_col_name = schemas
+ .iter()
+ .flat_map(|schema| schema.fields())
+ .map(|field| field.name())
+ .find(|col_name| {
+ no_relation_cols.entry(col_name).and_modify(|v| *v += 1);
+ matches!(
+ no_relation_cols.get_key_value(col_name.as_str()),
+ Some((_, 2..))
+ )
+ });
+
+ if let Some(col_name) = ambiguous_col_name {
+ let maybe_field = schemas
+ .iter()
+ .flat_map(|schema| {
+ schema
+ .field_with_unqualified_name(col_name)
+ .map(|f| f.qualified_name())
+ .ok()
+ })
+ .collect::<Vec<_>>();
+ Err(DataFusionError::Plan(format!(
+ "reference \'{}\' is ambiguous, could be {};",
+ col_name,
+ maybe_field.join(","),
+ )))
+ } else {
+ Ok(())
+ }
+}
+
#[cfg(test)]
mod tests {
use super::*;
diff --git a/datafusion/sql/src/relation/join.rs
b/datafusion/sql/src/relation/join.rs
index 6f2233f39..591194136 100644
--- a/datafusion/sql/src/relation/join.rs
+++ b/datafusion/sql/src/relation/join.rs
@@ -17,11 +17,10 @@
use crate::planner::{ContextProvider, PlannerContext, SqlToRel};
use crate::utils::normalize_ident;
-use datafusion_common::{Column, DFSchemaRef, DataFusionError, Result};
-use datafusion_expr::expr_rewriter::normalize_col_with_schemas;
-use datafusion_expr::{Expr, JoinType, LogicalPlan, LogicalPlanBuilder};
+use datafusion_common::{Column, DataFusionError, Result};
+use datafusion_expr::{JoinType, LogicalPlan, LogicalPlanBuilder};
use sqlparser::ast::{Join, JoinConstraint, JoinOperator, TableWithJoins};
-use std::collections::{HashMap, HashSet};
+use std::collections::HashSet;
impl<'a, S: ContextProvider> SqlToRel<'a, S> {
pub(crate) fn plan_table_with_joins(
@@ -133,30 +132,14 @@ impl<'a, S: ContextProvider> SqlToRel<'a, S> {
match constraint {
JoinConstraint::On(sql_expr) => {
let join_schema = left.schema().join(right.schema())?;
-
// parse ON expression
let expr = self.sql_to_expr(sql_expr, &join_schema,
planner_context)?;
-
- // ambiguous check
- ensure_any_column_reference_is_unambiguous(
- &expr,
- &[left.schema().clone(), right.schema().clone()],
- )?;
-
- // normalize all columns in expression
- let using_columns = expr.to_columns()?;
- let filter = normalize_col_with_schemas(
- expr,
- &[left.schema(), right.schema()],
- &[using_columns],
- )?;
-
LogicalPlanBuilder::from(left)
.join(
right,
join_type,
(Vec::<Column>::new(), Vec::<Column>::new()),
- Some(filter),
+ Some(expr),
)?
.build()
}
@@ -198,62 +181,3 @@ impl<'a, S: ContextProvider> SqlToRel<'a, S> {
}
}
}
-
-/// Ensure any column reference of the expression is unambiguous.
-/// Assume we have two schema:
-/// schema1: a, b ,c
-/// schema2: a, d, e
-///
-/// `schema1.a + schema2.a` is unambiguous.
-/// `a + d` is ambiguous, because `a` may come from schema1 or schema2.
-fn ensure_any_column_reference_is_unambiguous(
- expr: &Expr,
- schemas: &[DFSchemaRef],
-) -> Result<()> {
- if schemas.len() == 1 {
- return Ok(());
- }
- // all referenced columns in the expression that don't have relation
- let referenced_cols = expr.to_columns()?;
- let mut no_relation_cols = referenced_cols
- .iter()
- .filter_map(|col| {
- if col.relation.is_none() {
- Some((col.name.as_str(), 0))
- } else {
- None
- }
- })
- .collect::<HashMap<&str, u8>>();
- // find the name of the column existing in multi schemas.
- let ambiguous_col_name = schemas
- .iter()
- .flat_map(|schema| schema.fields())
- .map(|field| field.name())
- .find(|col_name| {
- no_relation_cols.entry(col_name).and_modify(|v| *v += 1);
- matches!(
- no_relation_cols.get_key_value(col_name.as_str()),
- Some((_, 2..))
- )
- });
-
- if let Some(col_name) = ambiguous_col_name {
- let maybe_field = schemas
- .iter()
- .flat_map(|schema| {
- schema
- .field_with_unqualified_name(col_name)
- .map(|f| f.qualified_name())
- .ok()
- })
- .collect::<Vec<_>>();
- Err(DataFusionError::Plan(format!(
- "reference \'{}\' is ambiguous, could be {};",
- col_name,
- maybe_field.join(","),
- )))
- } else {
- Ok(())
- }
-}
diff --git a/docs/source/user-guide/dataframe.md
b/docs/source/user-guide/dataframe.md
index 5ba803fce..c7d490e40 100644
--- a/docs/source/user-guide/dataframe.md
+++ b/docs/source/user-guide/dataframe.md
@@ -68,6 +68,7 @@ execution. The plan is evaluated (executed) when an action
method is invoked, su
| filter | Filter a DataFrame to only include rows that match the
specified filter expression.
|
| intersect | Calculate the intersection of two DataFrames. The two
DataFrames must have exactly the same schema
|
| join | Join this DataFrame with another DataFrame using the
specified columns as join keys.
|
+| join_on | Join this DataFrame with another DataFrame using
arbitrary expressions.
|
| limit | Limit the number of rows returned from this DataFrame.
|
| repartition | Repartition a DataFrame based on a logical
partitioning scheme.
|
| sort | Sort the DataFrame by the specified sorting
expressions. Any expression can be turned into a sort expression by calling its
`sort` method. |