This is an automated email from the ASF dual-hosted git repository.
alamb pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/arrow-datafusion.git
The following commit(s) were added to refs/heads/master by this push:
new 8d36529ab Refactor: Change equijoin keys from column to expression in
logical join (#4602)
8d36529ab is described below
commit 8d36529abab500f4effd805e204b78ff0ab03d17
Author: ygf11 <[email protected]>
AuthorDate: Sat Dec 17 19:06:40 2022 +0800
Refactor: Change equijoin keys from column to expression in logical join
(#4602)
* Refactor: Change equi-join keys from columns to expressions in logical join plan
* resolve conflicts
* add tests
* remove unused code
* fix test
* add test
* remove unused code
* fmt
* clippy
* resolve merge conflict
* Update datafusion/expr/src/logical_plan/builder.rs
Co-authored-by: Andrew Lamb <[email protected]>
* rename join_keys to equi_exprs
Co-authored-by: Andrew Lamb <[email protected]>
---
datafusion/core/src/physical_plan/planner.rs | 78 ++++-
datafusion/core/tests/sql/joins.rs | 368 +++++++++++++++++++--
datafusion/core/tests/sql/mod.rs | 4 +-
datafusion/expr/src/logical_plan/builder.rs | 112 ++++++-
datafusion/expr/src/logical_plan/plan.rs | 59 +++-
datafusion/optimizer/src/eliminate_cross_join.rs | 84 ++---
datafusion/optimizer/src/filter_null_join_keys.rs | 130 +++++---
datafusion/optimizer/src/push_down_filter.rs | 12 +-
datafusion/optimizer/src/push_down_projection.rs | 4 +-
.../optimizer/src/subquery_filter_to_join.rs | 2 +-
datafusion/proto/proto/datafusion.proto | 4 +-
datafusion/proto/src/generated/pbjson.rs | 52 +--
datafusion/proto/src/generated/prost.rs | 4 +-
datafusion/proto/src/logical_plan.rs | 47 ++-
datafusion/sql/src/planner.rs | 131 +++-----
15 files changed, 807 insertions(+), 284 deletions(-)
diff --git a/datafusion/core/src/physical_plan/planner.rs
b/datafusion/core/src/physical_plan/planner.rs
index 5ba8cabe9..53c88d5e9 100644
--- a/datafusion/core/src/physical_plan/planner.rs
+++ b/datafusion/core/src/physical_plan/planner.rs
@@ -65,6 +65,8 @@ use datafusion_expr::expr::{
Between, BinaryExpr, Cast, GetIndexedField, GroupingSet, Like, TryCast,
};
use datafusion_expr::expr_rewriter::unnormalize_cols;
+use datafusion_expr::logical_plan;
+use
datafusion_expr::logical_plan::builder::wrap_projection_for_join_if_necessary;
use datafusion_expr::utils::expand_wildcard;
use datafusion_expr::{WindowFrame, WindowFrameBound};
use datafusion_optimizer::utils::unalias;
@@ -846,8 +848,78 @@ impl DefaultPhysicalPlanner {
filter,
join_type,
null_equals_null,
+ schema: join_schema,
..
}) => {
+ // If join has expression equijoin keys, add physical
projection.
+ let has_expr_join_key = keys.iter().any(|(l, r)| {
+ !(matches!(l, Expr::Column(_))
+ && matches!(r, Expr::Column(_)))
+ });
+ if has_expr_join_key {
+ let left_keys = keys
+ .iter()
+ .map(|(l, _r)| l)
+ .cloned()
+ .collect::<Vec<_>>();
+ let right_keys = keys
+ .iter()
+ .map(|(_l, r)| r)
+ .cloned()
+ .collect::<Vec<_>>();
+ let (left, right, column_on, added_project) = {
+ let (left, left_col_keys, left_projected) =
+ wrap_projection_for_join_if_necessary(
+ left_keys.as_slice(),
+ left.as_ref().clone(),
+ )?;
+ let (right, right_col_keys, right_projected) =
+ wrap_projection_for_join_if_necessary(
+ &right_keys,
+ right.as_ref().clone(),
+ )?;
+ (
+ left,
+ right,
+ (left_col_keys, right_col_keys),
+ left_projected || right_projected,
+ )
+ };
+
+ let join_plan =
+ LogicalPlan::Join(Join::try_new_with_project_input(
+ logical_plan,
+ Arc::new(left),
+ Arc::new(right),
+ column_on,
+ )?);
+
+ // Remove temporary projected columns
+ let join_plan = if added_project {
+ let final_join_result = join_schema
+ .fields()
+ .iter()
+ .map(|field| {
+ Expr::Column(field.qualified_column())
+ })
+ .collect::<Vec<_>>();
+ let projection =
+ logical_plan::Projection::try_new_with_schema(
+ final_join_result,
+ Arc::new(join_plan),
+ join_schema.clone(),
+ )?;
+ LogicalPlan::Projection(projection)
+ } else {
+ join_plan
+ };
+
+ return self
+ .create_initial_plan(&join_plan, session_state)
+ .await;
+ }
+
+ // All equi-join keys are columns now, create physical
join plan
let left_df_schema = left.schema();
let physical_left = self.create_initial_plan(left,
session_state).await?;
let right_df_schema = right.schema();
@@ -855,9 +927,11 @@ impl DefaultPhysicalPlanner {
let join_on = keys
.iter()
.map(|(l, r)| {
+ let l = l.try_into_col()?;
+ let r = r.try_into_col()?;
Ok((
- Column::new(&l.name,
left_df_schema.index_of_column(l)?),
- Column::new(&r.name,
right_df_schema.index_of_column(r)?),
+ Column::new(&l.name,
left_df_schema.index_of_column(&l)?),
+ Column::new(&r.name,
right_df_schema.index_of_column(&r)?),
))
})
.collect::<Result<join_utils::JoinOn>>()?;
diff --git a/datafusion/core/tests/sql/joins.rs
b/datafusion/core/tests/sql/joins.rs
index 1094a818f..e3f92cbb3 100644
--- a/datafusion/core/tests/sql/joins.rs
+++ b/datafusion/core/tests/sql/joins.rs
@@ -2322,12 +2322,11 @@ async fn reduce_cross_join_with_expr_join_key_all() ->
Result<()> {
let expected = vec![
"Explain [plan_type:Utf8, plan:Utf8]",
" Projection: t1.t1_id, t1.t1_name, t1.t1_int, t2.t2_id,
t2.t2_name, t2.t2_int [t1_id:UInt32;N, t1_name:Utf8;N, t1_int:UInt32;N,
t2_id:UInt32;N, t2_name:Utf8;N, t2_int:UInt32;N]",
- " Inner Join: t1.t1_id + Int64(12) = t2.t2_id + Int64(1)
[t1_id:UInt32;N, t1_name:Utf8;N, t1_int:UInt32;N, t1.t1_id + Int64(12):Int64;N,
t2_id:UInt32;N, t2_name:Utf8;N, t2_int:UInt32;N, t2.t2_id + Int64(1):Int64;N]",
- " Projection: t1.t1_id, t1.t1_name, t1.t1_int, CAST(t1.t1_id
AS Int64) + Int64(12) [t1_id:UInt32;N, t1_name:Utf8;N, t1_int:UInt32;N,
t1.t1_id + Int64(12):Int64;N]",
- " TableScan: t1 projection=[t1_id, t1_name, t1_int]
[t1_id:UInt32;N, t1_name:Utf8;N, t1_int:UInt32;N]",
- " Projection: t2.t2_id, t2.t2_name, t2.t2_int, CAST(t2.t2_id
AS Int64) + Int64(1) [t2_id:UInt32;N, t2_name:Utf8;N, t2_int:UInt32;N, t2.t2_id
+ Int64(1):Int64;N]",
- " TableScan: t2 projection=[t2_id, t2_name, t2_int]
[t2_id:UInt32;N, t2_name:Utf8;N, t2_int:UInt32;N]",
+ " Inner Join: CAST(t1.t1_id AS Int64) + Int64(12) =
CAST(t2.t2_id AS Int64) + Int64(1) [t1_id:UInt32;N, t1_name:Utf8;N,
t1_int:UInt32;N, t2_id:UInt32;N, t2_name:Utf8;N, t2_int:UInt32;N]",
+ " TableScan: t1 projection=[t1_id, t1_name, t1_int]
[t1_id:UInt32;N, t1_name:Utf8;N, t1_int:UInt32;N]",
+ " TableScan: t2 projection=[t2_id, t2_name, t2_int]
[t2_id:UInt32;N, t2_name:Utf8;N, t2_int:UInt32;N]",
];
+
let formatted = plan.display_indent_schema().to_string();
let actual: Vec<&str> = formatted.trim().lines().collect();
assert_eq!(
@@ -2367,15 +2366,13 @@ async fn reduce_cross_join_with_cast_expr_join_key() ->
Result<()> {
let state = ctx.state();
let plan = state.optimize(&plan)?;
let expected = vec![
- "Explain [plan_type:Utf8, plan:Utf8]",
- " Projection: t1.t1_id, t2.t2_id, t1.t1_name [t1_id:UInt32;N,
t2_id:UInt32;N, t1_name:Utf8;N]",
- " Projection: t1.t1_id, t1.t1_name, t2.t2_id [t1_id:UInt32;N,
t1_name:Utf8;N, t2_id:UInt32;N]",
- " Inner Join: t1.t1_id + Int64(11) = CAST(t2.t2_id AS Int64)
[t1_id:UInt32;N, t1_name:Utf8;N, t1.t1_id + Int64(11):Int64;N, t2_id:UInt32;N,
CAST(t2.t2_id AS Int64):Int64;N]",
- " Projection: t1.t1_id, t1.t1_name, CAST(t1.t1_id AS Int64)
+ Int64(11) [t1_id:UInt32;N, t1_name:Utf8;N, t1.t1_id + Int64(11):Int64;N]",
- " TableScan: t1 projection=[t1_id, t1_name]
[t1_id:UInt32;N, t1_name:Utf8;N]",
- " Projection: t2.t2_id, CAST(t2.t2_id AS Int64) AS
CAST(t2.t2_id AS Int64) [t2_id:UInt32;N, CAST(t2.t2_id AS Int64):Int64;N]",
- " TableScan: t2 projection=[t2_id] [t2_id:UInt32;N]",
+ "Explain [plan_type:Utf8, plan:Utf8]",
+ " Projection: t1.t1_id, t2.t2_id, t1.t1_name [t1_id:UInt32;N,
t2_id:UInt32;N, t1_name:Utf8;N]",
+ " Inner Join: CAST(t1.t1_id AS Int64) + Int64(11) =
CAST(t2.t2_id AS Int64) [t1_id:UInt32;N, t1_name:Utf8;N, t2_id:UInt32;N]",
+ " TableScan: t1 projection=[t1_id, t1_name] [t1_id:UInt32;N,
t1_name:Utf8;N]",
+ " TableScan: t2 projection=[t2_id] [t2_id:UInt32;N]",
];
+
let formatted = plan.display_indent_schema().to_string();
let actual: Vec<&str> = formatted.trim().lines().collect();
assert_eq!(
@@ -2407,6 +2404,8 @@ async fn reduce_cross_join_with_wildcard_and_expr() ->
Result<()> {
let ctx = create_join_context("t1_id", "t2_id", repartition_joins)?;
let sql = "select *,t1.t1_id+11 from t1,t2 where t1.t1_id+11=t2.t2_id";
+
+ // assert logical plan
let msg = format!("Creating logical plan for '{}'", sql);
let plan = ctx
.create_logical_plan(&("explain ".to_owned() + sql))
@@ -2417,12 +2416,9 @@ async fn reduce_cross_join_with_wildcard_and_expr() ->
Result<()> {
let expected = vec![
"Explain [plan_type:Utf8, plan:Utf8]",
" Projection: t1.t1_id, t1.t1_name, t1.t1_int, t2.t2_id,
t2.t2_name, t2.t2_int, CAST(t1.t1_id AS Int64) + Int64(11) [t1_id:UInt32;N,
t1_name:Utf8;N, t1_int:UInt32;N, t2_id:UInt32;N, t2_name:Utf8;N,
t2_int:UInt32;N, t1.t1_id + Int64(11):Int64;N]",
- " Projection: t1.t1_id, t1.t1_name, t1.t1_int, t2.t2_id,
t2.t2_name, t2.t2_int [t1_id:UInt32;N, t1_name:Utf8;N, t1_int:UInt32;N,
t2_id:UInt32;N, t2_name:Utf8;N, t2_int:UInt32;N]",
- " Inner Join: t1.t1_id + Int64(11) = CAST(t2.t2_id AS Int64)
[t1_id:UInt32;N, t1_name:Utf8;N, t1_int:UInt32;N, t1.t1_id + Int64(11):Int64;N,
t2_id:UInt32;N, t2_name:Utf8;N, t2_int:UInt32;N, CAST(t2.t2_id AS
Int64):Int64;N]",
- " Projection: t1.t1_id, t1.t1_name, t1.t1_int,
CAST(t1.t1_id AS Int64) + Int64(11) [t1_id:UInt32;N, t1_name:Utf8;N,
t1_int:UInt32;N, t1.t1_id + Int64(11):Int64;N]",
- " TableScan: t1 projection=[t1_id, t1_name, t1_int]
[t1_id:UInt32;N, t1_name:Utf8;N, t1_int:UInt32;N]",
- " Projection: t2.t2_id, t2.t2_name, t2.t2_int,
CAST(t2.t2_id AS Int64) AS CAST(t2.t2_id AS Int64) [t2_id:UInt32;N,
t2_name:Utf8;N, t2_int:UInt32;N, CAST(t2.t2_id AS Int64):Int64;N]",
- " TableScan: t2 projection=[t2_id, t2_name, t2_int]
[t2_id:UInt32;N, t2_name:Utf8;N, t2_int:UInt32;N]",
+ " Inner Join: CAST(t1.t1_id AS Int64) + Int64(11) =
CAST(t2.t2_id AS Int64) [t1_id:UInt32;N, t1_name:Utf8;N, t1_int:UInt32;N,
t2_id:UInt32;N, t2_name:Utf8;N, t2_int:UInt32;N]",
+ " TableScan: t1 projection=[t1_id, t1_name, t1_int]
[t1_id:UInt32;N, t1_name:Utf8;N, t1_int:UInt32;N]",
+ " TableScan: t2 projection=[t2_id, t2_name, t2_int]
[t2_id:UInt32;N, t2_name:Utf8;N, t2_int:UInt32;N]"
];
let formatted = plan.display_indent_schema().to_string();
@@ -2432,6 +2428,54 @@ async fn reduce_cross_join_with_wildcard_and_expr() ->
Result<()> {
"\n\nexpected:\n\n{:#?}\nactual:\n\n{:#?}\n\n",
expected, actual
);
+
+ // assert physical plan
+ let msg = format!("Creating physical plan for '{}'", sql);
+ let plan = ctx.create_logical_plan(sql).expect(&msg);
+ let state = ctx.state();
+ let logical_plan = state.optimize(&plan)?;
+ let physical_plan = state.create_physical_plan(&logical_plan).await?;
+ let expected = if repartition_joins {
+ vec![
+ "ProjectionExec: expr=[t1_id@0 as t1_id, t1_name@1 as t1_name,
t1_int@2 as t1_int, t2_id@3 as t2_id, t2_name@4 as t2_name, t2_int@5 as t2_int,
CAST(t1_id@0 AS Int64) + 11 as t1.t1_id + Int64(11)]",
+ " ProjectionExec: expr=[t1_id@0 as t1_id, t1_name@1 as
t1_name, t1_int@2 as t1_int, t2_id@4 as t2_id, t2_name@5 as t2_name, t2_int@6
as t2_int]",
+ " CoalesceBatchesExec: target_batch_size=4096",
+ " HashJoinExec: mode=Partitioned, join_type=Inner,
on=[(Column { name: \"t1.t1_id + Int64(11)\", index: 3 }, Column { name:
\"CAST(t2.t2_id AS Int64)\", index: 3 })]",
+ " CoalesceBatchesExec: target_batch_size=4096",
+ " RepartitionExec: partitioning=Hash([Column { name:
\"t1.t1_id + Int64(11)\", index: 3 }], 2)",
+ " ProjectionExec: expr=[t1_id@0 as t1_id, t1_name@1
as t1_name, t1_int@2 as t1_int, CAST(t1_id@0 AS Int64) + 11 as t1.t1_id +
Int64(11)]",
+ " RepartitionExec:
partitioning=RoundRobinBatch(2)",
+ " MemoryExec: partitions=1,
partition_sizes=[1]",
+ " CoalesceBatchesExec: target_batch_size=4096",
+ " RepartitionExec: partitioning=Hash([Column { name:
\"CAST(t2.t2_id AS Int64)\", index: 3 }], 2)",
+ " ProjectionExec: expr=[t2_id@0 as t2_id, t2_name@1
as t2_name, t2_int@2 as t2_int, CAST(t2_id@0 AS Int64) as CAST(t2.t2_id AS
Int64)]",
+ " RepartitionExec:
partitioning=RoundRobinBatch(2)",
+ " MemoryExec: partitions=1,
partition_sizes=[1]",
+ ]
+ } else {
+ vec![
+ "ProjectionExec: expr=[t1_id@0 as t1_id, t1_name@1 as t1_name,
t1_int@2 as t1_int, t2_id@3 as t2_id, t2_name@4 as t2_name, t2_int@5 as t2_int,
CAST(t1_id@0 AS Int64) + 11 as t1.t1_id + Int64(11)]",
+ " ProjectionExec: expr=[t1_id@0 as t1_id, t1_name@1 as
t1_name, t1_int@2 as t1_int, t2_id@4 as t2_id, t2_name@5 as t2_name, t2_int@6
as t2_int]",
+ " CoalesceBatchesExec: target_batch_size=4096",
+ " HashJoinExec: mode=CollectLeft, join_type=Inner,
on=[(Column { name: \"t1.t1_id + Int64(11)\", index: 3 }, Column { name:
\"CAST(t2.t2_id AS Int64)\", index: 3 })]",
+ " CoalescePartitionsExec",
+ " ProjectionExec: expr=[t1_id@0 as t1_id, t1_name@1
as t1_name, t1_int@2 as t1_int, CAST(t1_id@0 AS Int64) + 11 as t1.t1_id +
Int64(11)]",
+ " RepartitionExec: partitioning=RoundRobinBatch(2)",
+ " MemoryExec: partitions=1, partition_sizes=[1]",
+ " ProjectionExec: expr=[t2_id@0 as t2_id, t2_name@1 as
t2_name, t2_int@2 as t2_int, CAST(t2_id@0 AS Int64) as CAST(t2.t2_id AS
Int64)]",
+ " RepartitionExec: partitioning=RoundRobinBatch(2)",
+ " MemoryExec: partitions=1, partition_sizes=[1]",
+ ]
+ };
+ let formatted =
displayable(physical_plan.as_ref()).indent().to_string();
+ let actual: Vec<&str> = formatted.trim().lines().collect();
+ assert_eq!(
+ expected, actual,
+ "\n\nexpected:\n\n{:#?}\nactual:\n\n{:#?}\n\n",
+ expected, actual
+ );
+
+ // assert execution result
let expected = vec![
"+-------+---------+--------+-------+---------+--------+----------------------+",
"| t1_id | t1_name | t1_int | t2_id | t2_name | t2_int | t1.t1_id
+ Int64(11) |",
@@ -2448,3 +2492,289 @@ async fn reduce_cross_join_with_wildcard_and_expr() ->
Result<()> {
Ok(())
}
+
+#[tokio::test]
+async fn both_side_expr_key_inner_join() -> Result<()> {
+ let test_repartition_joins = vec![true, false];
+ for repartition_joins in test_repartition_joins {
+ let ctx = create_join_context("t1_id", "t2_id", repartition_joins)?;
+
+ let sql = "SELECT t1.t1_id, t2.t2_id, t1.t1_name \
+ FROM t1 \
+ INNER JOIN t2 \
+ ON t1.t1_id + cast(12 as INT UNSIGNED) = t2.t2_id +
cast(1 as INT UNSIGNED)";
+
+ let msg = format!("Creating logical plan for '{}'", sql);
+ let plan = ctx.create_logical_plan(sql).expect(&msg);
+ let state = ctx.state();
+ let logical_plan = state.optimize(&plan)?;
+ let physical_plan = state.create_physical_plan(&logical_plan).await?;
+
+ let expected = if repartition_joins {
+ vec![
+ "ProjectionExec: expr=[t1_id@0 as t1_id, t2_id@2 as t2_id,
t1_name@1 as t1_name]",
+ " ProjectionExec: expr=[t1_id@0 as t1_id, t1_name@1 as
t1_name, t2_id@3 as t2_id]",
+ " CoalesceBatchesExec: target_batch_size=4096",
+ " HashJoinExec: mode=Partitioned, join_type=Inner,
on=[(Column { name: \"t1.t1_id + Int64(12)\", index: 2 }, Column { name:
\"t2.t2_id + Int64(1)\", index: 1 })]",
+ " CoalesceBatchesExec: target_batch_size=4096",
+ " RepartitionExec: partitioning=Hash([Column { name:
\"t1.t1_id + Int64(12)\", index: 2 }], 2)",
+ " ProjectionExec: expr=[t1_id@0 as t1_id, t1_name@1
as t1_name, t1_id@0 + CAST(12 AS UInt32) as t1.t1_id + Int64(12)]",
+ " RepartitionExec:
partitioning=RoundRobinBatch(2)",
+ " MemoryExec: partitions=1,
partition_sizes=[1]",
+ " CoalesceBatchesExec: target_batch_size=4096",
+ " RepartitionExec: partitioning=Hash([Column { name:
\"t2.t2_id + Int64(1)\", index: 1 }], 2)",
+ " ProjectionExec: expr=[t2_id@0 as t2_id, t2_id@0 +
CAST(1 AS UInt32) as t2.t2_id + Int64(1)]",
+ " RepartitionExec:
partitioning=RoundRobinBatch(2)",
+ " MemoryExec: partitions=1,
partition_sizes=[1]",
+ ]
+ } else {
+ vec![
+ "ProjectionExec: expr=[t1_id@0 as t1_id, t2_id@2 as t2_id,
t1_name@1 as t1_name]",
+ " ProjectionExec: expr=[t1_id@0 as t1_id, t1_name@1 as
t1_name, t2_id@3 as t2_id]",
+ " CoalesceBatchesExec: target_batch_size=4096",
+ " HashJoinExec: mode=CollectLeft, join_type=Inner,
on=[(Column { name: \"t1.t1_id + Int64(12)\", index: 2 }, Column { name:
\"t2.t2_id + Int64(1)\", index: 1 })]",
+ " CoalescePartitionsExec",
+ " ProjectionExec: expr=[t1_id@0 as t1_id, t1_name@1
as t1_name, t1_id@0 + CAST(12 AS UInt32) as t1.t1_id + Int64(12)]",
+ " RepartitionExec: partitioning=RoundRobinBatch(2)",
+ " MemoryExec: partitions=1, partition_sizes=[1]",
+ " ProjectionExec: expr=[t2_id@0 as t2_id, t2_id@0 +
CAST(1 AS UInt32) as t2.t2_id + Int64(1)]",
+ " RepartitionExec: partitioning=RoundRobinBatch(2)",
+ " MemoryExec: partitions=1, partition_sizes=[1]",
+ ]
+ };
+ let formatted =
displayable(physical_plan.as_ref()).indent().to_string();
+ let actual: Vec<&str> = formatted.trim().lines().collect();
+ assert_eq!(
+ expected, actual,
+ "\n\nexpected:\n\n{:#?}\nactual:\n\n{:#?}\n\n",
+ expected, actual
+ );
+
+ let expected = vec![
+ "+-------+-------+---------+",
+ "| t1_id | t2_id | t1_name |",
+ "+-------+-------+---------+",
+ "| 11 | 22 | a |",
+ "| 33 | 44 | c |",
+ "| 44 | 55 | d |",
+ "+-------+-------+---------+",
+ ];
+
+ let results = execute_to_batches(&ctx, sql).await;
+ assert_batches_sorted_eq!(expected, &results);
+ }
+
+ Ok(())
+}
+
+#[tokio::test]
+async fn left_side_expr_key_inner_join() -> Result<()> {
+ let test_repartition_joins = vec![true, false];
+ for repartition_joins in test_repartition_joins {
+ let ctx = create_join_context("t1_id", "t2_id", repartition_joins)?;
+
+ let sql = "SELECT t1.t1_id, t2.t2_id, t1.t1_name \
+ FROM t1 \
+ INNER JOIN t2 \
+ ON t1.t1_id + cast(11 as INT UNSIGNED) = t2.t2_id";
+
+ let msg = format!("Creating logical plan for '{}'", sql);
+ let plan = ctx.create_logical_plan(sql).expect(&msg);
+ let state = ctx.state();
+ let logical_plan = state.optimize(&plan)?;
+ let physical_plan = state.create_physical_plan(&logical_plan).await?;
+
+ let expected = if repartition_joins {
+ vec![
+ "ProjectionExec: expr=[t1_id@0 as t1_id, t2_id@2 as t2_id,
t1_name@1 as t1_name]",
+ " ProjectionExec: expr=[t1_id@0 as t1_id, t1_name@1 as
t1_name, t2_id@3 as t2_id]",
+ " CoalesceBatchesExec: target_batch_size=4096",
+ " HashJoinExec: mode=Partitioned, join_type=Inner,
on=[(Column { name: \"t1.t1_id + Int64(11)\", index: 2 }, Column { name:
\"t2_id\", index: 0 })]",
+ " CoalesceBatchesExec: target_batch_size=4096",
+ " RepartitionExec: partitioning=Hash([Column { name:
\"t1.t1_id + Int64(11)\", index: 2 }], 2)",
+ " ProjectionExec: expr=[t1_id@0 as t1_id, t1_name@1
as t1_name, t1_id@0 + CAST(11 AS UInt32) as t1.t1_id + Int64(11)]",
+ " RepartitionExec:
partitioning=RoundRobinBatch(2)",
+ " MemoryExec: partitions=1,
partition_sizes=[1]",
+ " CoalesceBatchesExec: target_batch_size=4096",
+ " RepartitionExec: partitioning=Hash([Column { name:
\"t2_id\", index: 0 }], 2)",
+ " RepartitionExec: partitioning=RoundRobinBatch(2)",
+ " MemoryExec: partitions=1, partition_sizes=[1]",
+ ]
+ } else {
+ vec![
+ "ProjectionExec: expr=[t1_id@0 as t1_id, t2_id@2 as t2_id,
t1_name@1 as t1_name]",
+ " ProjectionExec: expr=[t1_id@0 as t1_id, t1_name@1 as
t1_name, t2_id@3 as t2_id]",
+ " CoalesceBatchesExec: target_batch_size=4096",
+ " RepartitionExec: partitioning=RoundRobinBatch(2)",
+ " HashJoinExec: mode=CollectLeft, join_type=Inner,
on=[(Column { name: \"t1.t1_id + Int64(11)\", index: 2 }, Column { name:
\"t2_id\", index: 0 })]",
+ " CoalescePartitionsExec",
+ " ProjectionExec: expr=[t1_id@0 as t1_id, t1_name@1
as t1_name, t1_id@0 + CAST(11 AS UInt32) as t1.t1_id + Int64(11)]",
+ " RepartitionExec:
partitioning=RoundRobinBatch(2)",
+ " MemoryExec: partitions=1,
partition_sizes=[1]",
+ " MemoryExec: partitions=1, partition_sizes=[1]",
+ ]
+ };
+ let formatted =
displayable(physical_plan.as_ref()).indent().to_string();
+ let actual: Vec<&str> = formatted.trim().lines().collect();
+ assert_eq!(
+ expected, actual,
+ "\n\nexpected:\n\n{:#?}\nactual:\n\n{:#?}\n\n",
+ expected, actual
+ );
+
+ let expected = vec![
+ "+-------+-------+---------+",
+ "| t1_id | t2_id | t1_name |",
+ "+-------+-------+---------+",
+ "| 11 | 22 | a |",
+ "| 33 | 44 | c |",
+ "| 44 | 55 | d |",
+ "+-------+-------+---------+",
+ ];
+
+ let results = execute_to_batches(&ctx, sql).await;
+ assert_batches_sorted_eq!(expected, &results);
+ }
+
+ Ok(())
+}
+
+#[tokio::test]
+async fn right_side_expr_key_inner_join() -> Result<()> {
+ let test_repartition_joins = vec![true, false];
+ for repartition_joins in test_repartition_joins {
+ let ctx = create_join_context("t1_id", "t2_id", repartition_joins)?;
+
+ let sql = "SELECT t1.t1_id, t2.t2_id, t1.t1_name \
+ FROM t1 \
+ INNER JOIN t2 \
+ ON t1.t1_id = t2.t2_id - cast(11 as INT UNSIGNED)";
+
+ let msg = format!("Creating logical plan for '{}'", sql);
+ let plan = ctx.create_logical_plan(sql).expect(&msg);
+ let state = ctx.state();
+ let logical_plan = state.optimize(&plan)?;
+ let physical_plan = state.create_physical_plan(&logical_plan).await?;
+
+ let expected = if repartition_joins {
+ vec![
+ "ProjectionExec: expr=[t1_id@0 as t1_id, t2_id@2 as t2_id,
t1_name@1 as t1_name]",
+ " ProjectionExec: expr=[t1_id@0 as t1_id, t1_name@1 as
t1_name, t2_id@2 as t2_id]",
+ " CoalesceBatchesExec: target_batch_size=4096",
+ " HashJoinExec: mode=Partitioned, join_type=Inner,
on=[(Column { name: \"t1_id\", index: 0 }, Column { name: \"t2.t2_id -
Int64(11)\", index: 1 })]",
+ " CoalesceBatchesExec: target_batch_size=4096",
+ " RepartitionExec: partitioning=Hash([Column { name:
\"t1_id\", index: 0 }], 2)",
+ " RepartitionExec: partitioning=RoundRobinBatch(2)",
+ " MemoryExec: partitions=1, partition_sizes=[1]",
+ " CoalesceBatchesExec: target_batch_size=4096",
+ " RepartitionExec: partitioning=Hash([Column { name:
\"t2.t2_id - Int64(11)\", index: 1 }], 2)",
+ " ProjectionExec: expr=[t2_id@0 as t2_id, t2_id@0 -
CAST(11 AS UInt32) as t2.t2_id - Int64(11)]",
+ " RepartitionExec:
partitioning=RoundRobinBatch(2)",
+ " MemoryExec: partitions=1,
partition_sizes=[1]",
+ ]
+ } else {
+ vec![
+ "ProjectionExec: expr=[t1_id@0 as t1_id, t2_id@2 as t2_id,
t1_name@1 as t1_name]",
+ " ProjectionExec: expr=[t1_id@0 as t1_id, t1_name@1 as
t1_name, t2_id@2 as t2_id]",
+ " CoalesceBatchesExec: target_batch_size=4096",
+ " HashJoinExec: mode=CollectLeft, join_type=Inner,
on=[(Column { name: \"t1_id\", index: 0 }, Column { name: \"t2.t2_id -
Int64(11)\", index: 1 })]",
+ " MemoryExec: partitions=1, partition_sizes=[1]",
+ " ProjectionExec: expr=[t2_id@0 as t2_id, t2_id@0 -
CAST(11 AS UInt32) as t2.t2_id - Int64(11)]",
+ " RepartitionExec: partitioning=RoundRobinBatch(2)",
+ " MemoryExec: partitions=1, partition_sizes=[1]",
+ ]
+ };
+ let formatted =
displayable(physical_plan.as_ref()).indent().to_string();
+ let actual: Vec<&str> = formatted.trim().lines().collect();
+ assert_eq!(
+ expected, actual,
+ "\n\nexpected:\n\n{:#?}\nactual:\n\n{:#?}\n\n",
+ expected, actual
+ );
+
+ let expected = vec![
+ "+-------+-------+---------+",
+ "| t1_id | t2_id | t1_name |",
+ "+-------+-------+---------+",
+ "| 11 | 22 | a |",
+ "| 33 | 44 | c |",
+ "| 44 | 55 | d |",
+ "+-------+-------+---------+",
+ ];
+
+ let results = execute_to_batches(&ctx, sql).await;
+ assert_batches_sorted_eq!(expected, &results);
+ }
+
+ Ok(())
+}
+
+#[tokio::test]
+async fn select_wildcard_with_expr_key_inner_join() -> Result<()> {
+ let test_repartition_joins = vec![true, false];
+ for repartition_joins in test_repartition_joins {
+ let ctx = create_join_context("t1_id", "t2_id", repartition_joins)?;
+
+ let sql = "SELECT * \
+ FROM t1 \
+ INNER JOIN t2 \
+ ON t1.t1_id = t2.t2_id - cast(11 as INT UNSIGNED)";
+
+ let msg = format!("Creating logical plan for '{}'", sql);
+ let plan = ctx.create_logical_plan(sql).expect(&msg);
+ let state = ctx.state();
+ let logical_plan = state.optimize(&plan)?;
+ let physical_plan = state.create_physical_plan(&logical_plan).await?;
+
+ let expected = if repartition_joins {
+ vec![
+ "ProjectionExec: expr=[t1_id@0 as t1_id, t1_name@1 as t1_name,
t1_int@2 as t1_int, t2_id@3 as t2_id, t2_name@4 as t2_name, t2_int@5 as
t2_int]",
+ " ProjectionExec: expr=[t1_id@0 as t1_id, t1_name@1 as
t1_name, t1_int@2 as t1_int, t2_id@3 as t2_id, t2_name@4 as t2_name, t2_int@5
as t2_int]",
+ " CoalesceBatchesExec: target_batch_size=4096",
+ " HashJoinExec: mode=Partitioned, join_type=Inner,
on=[(Column { name: \"t1_id\", index: 0 }, Column { name: \"t2.t2_id -
Int64(11)\", index: 3 })]",
+ " CoalesceBatchesExec: target_batch_size=4096",
+ " RepartitionExec: partitioning=Hash([Column { name:
\"t1_id\", index: 0 }], 2)",
+ " RepartitionExec: partitioning=RoundRobinBatch(2)",
+ " MemoryExec: partitions=1, partition_sizes=[1]",
+ " CoalesceBatchesExec: target_batch_size=4096",
+ " RepartitionExec: partitioning=Hash([Column { name:
\"t2.t2_id - Int64(11)\", index: 3 }], 2)",
+ " ProjectionExec: expr=[t2_id@0 as t2_id, t2_name@1
as t2_name, t2_int@2 as t2_int, t2_id@0 - CAST(11 AS UInt32) as t2.t2_id -
Int64(11)]",
+ " RepartitionExec:
partitioning=RoundRobinBatch(2)",
+ " MemoryExec: partitions=1,
partition_sizes=[1]",
+ ]
+ } else {
+ vec![
+ "ProjectionExec: expr=[t1_id@0 as t1_id, t1_name@1 as t1_name,
t1_int@2 as t1_int, t2_id@3 as t2_id, t2_name@4 as t2_name, t2_int@5 as
t2_int]",
+ " ProjectionExec: expr=[t1_id@0 as t1_id, t1_name@1 as
t1_name, t1_int@2 as t1_int, t2_id@3 as t2_id, t2_name@4 as t2_name, t2_int@5
as t2_int]",
+ " CoalesceBatchesExec: target_batch_size=4096",
+ " HashJoinExec: mode=CollectLeft, join_type=Inner,
on=[(Column { name: \"t1_id\", index: 0 }, Column { name: \"t2.t2_id -
Int64(11)\", index: 3 })]",
+ " MemoryExec: partitions=1, partition_sizes=[1]",
+ " ProjectionExec: expr=[t2_id@0 as t2_id, t2_name@1 as
t2_name, t2_int@2 as t2_int, t2_id@0 - CAST(11 AS UInt32) as t2.t2_id -
Int64(11)]",
+ " RepartitionExec: partitioning=RoundRobinBatch(2)",
+ " MemoryExec: partitions=1, partition_sizes=[1]",
+ ]
+ };
+ let formatted =
displayable(physical_plan.as_ref()).indent().to_string();
+ let actual: Vec<&str> = formatted.trim().lines().collect();
+ assert_eq!(
+ expected, actual,
+ "\n\nexpected:\n\n{:#?}\nactual:\n\n{:#?}\n\n",
+ expected, actual
+ );
+
+ let expected = vec![
+ "+-------+---------+--------+-------+---------+--------+",
+ "| t1_id | t1_name | t1_int | t2_id | t2_name | t2_int |",
+ "+-------+---------+--------+-------+---------+--------+",
+ "| 11 | a | 1 | 22 | y | 1 |",
+ "| 33 | c | 3 | 44 | x | 3 |",
+ "| 44 | d | 4 | 55 | w | 3 |",
+ "+-------+---------+--------+-------+---------+--------+",
+ ];
+
+ let results = execute_to_batches(&ctx, sql).await;
+ assert_batches_sorted_eq!(expected, &results);
+ }
+
+ Ok(())
+}
diff --git a/datafusion/core/tests/sql/mod.rs b/datafusion/core/tests/sql/mod.rs
index 1075041c2..d97eb04bd 100644
--- a/datafusion/core/tests/sql/mod.rs
+++ b/datafusion/core/tests/sql/mod.rs
@@ -185,7 +185,9 @@ fn create_join_context(
repartition_joins: bool,
) -> Result<SessionContext> {
let ctx = SessionContext::with_config(
- SessionConfig::new().with_repartition_joins(repartition_joins),
+ SessionConfig::new()
+ .with_repartition_joins(repartition_joins)
+ .with_target_partitions(2),
);
let t1_schema = Arc::new(Schema::new(vec![
diff --git a/datafusion/expr/src/logical_plan/builder.rs
b/datafusion/expr/src/logical_plan/builder.rs
index 2c16a3ca7..f0a9c9e5b 100644
--- a/datafusion/expr/src/logical_plan/builder.rs
+++ b/datafusion/expr/src/logical_plan/builder.rs
@@ -18,7 +18,8 @@
//! This module provides a builder for creating LogicalPlans
use crate::expr_rewriter::{
- coerce_plan_expr_for_schema, normalize_col, normalize_cols,
rewrite_sort_cols_by_aggs,
+ coerce_plan_expr_for_schema, normalize_col, normalize_col_with_schemas,
+ normalize_cols, rewrite_sort_cols_by_aggs,
};
use crate::type_coercion::binary::comparison_coercion;
use crate::utils::{columnize_expr, exprlist_to_fields, from_plan};
@@ -31,8 +32,8 @@ use crate::{
Union, Values, Window,
},
utils::{
- can_hash, expand_qualified_wildcard, expand_wildcard,
- group_window_expr_by_sort_keys,
+ can_hash, check_all_column_from_schema, expand_qualified_wildcard,
+ expand_wildcard, group_window_expr_by_sort_keys,
},
Expr, ExprSchemable, TableSource,
};
@@ -555,7 +556,11 @@ impl LogicalPlanBuilder {
let left_keys =
left_keys.into_iter().collect::<Result<Vec<Column>>>()?;
let right_keys =
right_keys.into_iter().collect::<Result<Vec<Column>>>()?;
- let on: Vec<(_, _)> =
left_keys.into_iter().zip(right_keys.into_iter()).collect();
+ let on = left_keys
+ .into_iter()
+ .zip(right_keys.into_iter())
+ .map(|(l, r)| (Expr::Column(l), Expr::Column(r)))
+ .collect();
let join_schema =
build_join_schema(self.plan.schema(), right.schema(), &join_type)?;
@@ -591,19 +596,19 @@ impl LogicalPlanBuilder {
let on: Vec<(_, _)> =
left_keys.into_iter().zip(right_keys.into_iter()).collect();
let join_schema =
build_join_schema(self.plan.schema(), right.schema(), &join_type)?;
- let mut join_on: Vec<(Column, Column)> = vec![];
+ let mut join_on: Vec<(Expr, Expr)> = vec![];
let mut filters: Option<Expr> = None;
for (l, r) in &on {
if self.plan.schema().field_from_column(l).is_ok()
&& right.schema().field_from_column(r).is_ok()
&&
can_hash(self.plan.schema().field_from_column(l)?.data_type())
{
- join_on.push((l.clone(), r.clone()));
+ join_on.push((Expr::Column(l.clone()),
Expr::Column(r.clone())));
} else if self.plan.schema().field_from_column(r).is_ok()
&& right.schema().field_from_column(l).is_ok()
&&
can_hash(self.plan.schema().field_from_column(r)?.data_type())
{
- join_on.push((r.clone(), l.clone()));
+ join_on.push((Expr::Column(r.clone()),
Expr::Column(l.clone())));
} else {
let expr = binary_expr(
Expr::Column(l.clone()),
@@ -616,6 +621,7 @@ impl LogicalPlanBuilder {
}
}
}
+
if join_on.is_empty() {
let join = Self::from(self.plan).cross_join(right)?;
join.filter(filters.ok_or_else(|| {
@@ -791,6 +797,98 @@ impl LogicalPlanBuilder {
pub fn build(self) -> Result<LogicalPlan> {
Ok(self.plan)
}
+
+ /// Apply a join with the expression on constraint.
+ ///
+ /// equi_exprs are "equijoin" predicate expressions on the existing and
right inputs, respectively.
+ ///
+ /// filter: any other filter expression to apply during the join.
equi_exprs predicates are likely
+ /// to be evaluated more quickly than the filter expressions.
+ pub fn join_with_expr_keys(
+ self,
+ right: LogicalPlan,
+ join_type: JoinType,
+ equi_exprs: (Vec<impl Into<Expr>>, Vec<impl Into<Expr>>),
+ filter: Option<Expr>,
+ ) -> Result<Self> {
+ if equi_exprs.0.len() != equi_exprs.1.len() {
+ return Err(DataFusionError::Plan(
+ "left_keys and right_keys were not the same
length".to_string(),
+ ));
+ }
+
+ let join_key_pairs = equi_exprs
+ .0
+ .into_iter()
+ .zip(equi_exprs.1.into_iter())
+ .map(|(l, r)| {
+ let left_key = l.into();
+ let right_key = r.into();
+
+ let left_using_columns = left_key.to_columns()?;
+ let normalized_left_key = normalize_col_with_schemas(
+ left_key,
+ &[self.plan.schema(), right.schema()],
+ &[left_using_columns],
+ )?;
+
+ let right_using_columns = right_key.to_columns()?;
+ let normalized_right_key = normalize_col_with_schemas(
+ right_key,
+ &[self.plan.schema(), right.schema()],
+ &[right_using_columns],
+ )?;
+
+ let normalized_left_using_columns =
normalized_left_key.to_columns()?;
+ let l_is_left = check_all_column_from_schema(
+ &normalized_left_using_columns,
+ self.plan.schema().clone(),
+ )?;
+
+ let normalized_right_using_columns =
normalized_right_key.to_columns()?;
+ let r_is_right = check_all_column_from_schema(
+ &normalized_right_using_columns,
+ right.schema().clone(),
+ )?;
+
+ let r_is_left_and_l_is_right = || {
+ let result = check_all_column_from_schema(
+ &normalized_right_using_columns,
+ self.plan.schema().clone(),
+ )? && check_all_column_from_schema(
+ &normalized_left_using_columns,
+ right.schema().clone(),
+ )?;
+ Result::Ok(result)
+ };
+
+ if l_is_left && r_is_right {
+ Ok((normalized_left_key, normalized_right_key))
+ } else if r_is_left_and_l_is_right()?{
+ Ok((normalized_right_key, normalized_left_key))
+ } else {
+ Err(DataFusionError::Plan(format!(
+ "can't create join plan, join key should belong to one
input, error key: ({},{})",
+ normalized_left_key, normalized_right_key
+ )))
+ }
+ })
+ .collect::<Result<Vec<_>>>()?;
+
+ let join_schema =
+ build_join_schema(self.plan.schema(), right.schema(), &join_type)?;
+
+ Ok(Self::from(LogicalPlan::Join(Join {
+ left: Arc::new(self.plan),
+ right: Arc::new(right),
+ on: join_key_pairs,
+ filter,
+ join_type,
+ join_constraint: JoinConstraint::On,
+ schema: DFSchemaRef::new(join_schema),
+ null_equals_null: false,
+ })))
+ }
}
/// Creates a schema for a join operation.
diff --git a/datafusion/expr/src/logical_plan/plan.rs
b/datafusion/expr/src/logical_plan/plan.rs
index f1eb3f135..9368b1ec4 100644
--- a/datafusion/expr/src/logical_plan/plan.rs
+++ b/datafusion/expr/src/logical_plan/plan.rs
@@ -25,7 +25,9 @@ use crate::utils::{
self, exprlist_to_fields, from_plan, grouping_set_expr_count,
grouping_set_to_exprlist,
};
-use crate::{Expr, ExprSchemable, TableProviderFilterPushDown, TableSource};
+use crate::{
+ build_join_schema, Expr, ExprSchemable, TableProviderFilterPushDown,
TableSource,
+};
use arrow::datatypes::{DataType, Field, Schema, SchemaRef};
use datafusion_common::{
plan_err, Column, DFSchema, DFSchemaRef, DataFusionError,
OwnedTableReference,
@@ -252,7 +254,7 @@ impl LogicalPlan {
}) => group_expr.iter().chain(aggr_expr.iter()).cloned().collect(),
LogicalPlan::Join(Join { on, filter, .. }) => on
.iter()
- .flat_map(|(l, r)| vec![Expr::Column(l.clone()),
Expr::Column(r.clone())])
+ .flat_map(|(l, r)| vec![l.clone(), r.clone()])
.chain(
filter
.as_ref()
@@ -343,12 +345,14 @@ impl LogicalPlan {
..
}) = plan
{
- self.using_columns.push(
- on.iter()
- .flat_map(|entry| [&entry.0, &entry.1])
- .cloned()
- .collect::<HashSet<Column>>(),
- );
+ // The join keys in using-join must be columns.
+ let columns =
+ on.iter().try_fold(HashSet::new(), |mut accumu, (l,
r)| {
+ accumu.insert(l.try_into_col()?);
+ accumu.insert(r.try_into_col()?);
+ Result::<_, DataFusionError>::Ok(accumu)
+ })?;
+ self.using_columns.push(columns);
}
Ok(true)
}
@@ -1646,8 +1650,8 @@ pub struct Join {
pub left: Arc<LogicalPlan>,
/// Right input
pub right: Arc<LogicalPlan>,
- /// Equijoin clause expressed as pairs of (left, right) join columns
- pub on: Vec<(Column, Column)>,
+ /// Equijoin clause expressed as pairs of (left, right) join expressions
+ pub on: Vec<(Expr, Expr)>,
/// Filters applied during join (non-equi conditions)
pub filter: Option<Expr>,
/// Join type
@@ -1660,6 +1664,41 @@ pub struct Join {
pub null_equals_null: bool,
}
+impl Join {
+ /// Create Join with input which wrapped with projection, this method is
used to help create physical join.
+ pub fn try_new_with_project_input(
+ original: &LogicalPlan,
+ left: Arc<LogicalPlan>,
+ right: Arc<LogicalPlan>,
+ column_on: (Vec<Column>, Vec<Column>),
+ ) -> Result<Self, DataFusionError> {
+ let original_join = match original {
+ LogicalPlan::Join(join) => join,
+ _ => return plan_err!("Could not create join with project input"),
+ };
+
+ let on: Vec<(Expr, Expr)> = column_on
+ .0
+ .into_iter()
+ .zip(column_on.1.into_iter())
+ .map(|(l, r)| (Expr::Column(l), Expr::Column(r)))
+ .collect();
+ let join_schema =
+ build_join_schema(left.schema(), right.schema(),
&original_join.join_type)?;
+
+ Ok(Join {
+ left,
+ right,
+ on,
+ filter: original_join.filter.clone(),
+ join_type: original_join.join_type,
+ join_constraint: original_join.join_constraint,
+ schema: Arc::new(join_schema),
+ null_equals_null: original_join.null_equals_null,
+ })
+ }
+}
+
/// Subquery
#[derive(Clone)]
pub struct Subquery {
diff --git a/datafusion/optimizer/src/eliminate_cross_join.rs
b/datafusion/optimizer/src/eliminate_cross_join.rs
index 338e58f68..6aaf02d26 100644
--- a/datafusion/optimizer/src/eliminate_cross_join.rs
+++ b/datafusion/optimizer/src/eliminate_cross_join.rs
@@ -19,18 +19,14 @@
use std::collections::HashSet;
use std::sync::Arc;
+use crate::{utils, OptimizerConfig, OptimizerRule};
use datafusion_common::{DataFusionError, Result};
use datafusion_expr::expr::{BinaryExpr, Expr};
use datafusion_expr::logical_plan::{
CrossJoin, Filter, Join, JoinConstraint, JoinType, LogicalPlan, Projection,
};
use datafusion_expr::utils::{can_hash, check_all_column_from_schema};
-use datafusion_expr::{
- and, build_join_schema, or, wrap_projection_for_join_if_necessary,
ExprSchemable,
- Operator,
-};
-
-use crate::{utils, OptimizerConfig, OptimizerRule};
+use datafusion_expr::{and, build_join_schema, or, ExprSchemable, Operator};
#[derive(Default)]
pub struct EliminateCrossJoin;
@@ -142,11 +138,7 @@ fn flatten_join_inputs(
) -> Result<()> {
let children = match plan {
LogicalPlan::Join(join) => {
- for join_keys in join.on.iter() {
- let join_keys = join_keys.clone();
- possible_join_keys
- .push((Expr::Column(join_keys.0),
Expr::Column(join_keys.1)));
- }
+ possible_join_keys.extend(join.on.clone());
let left = &*(join.left);
let right = &*(join.right);
Ok::<Vec<&LogicalPlan>, DataFusionError>(vec![left, right])
@@ -239,26 +231,12 @@ fn find_inner_join(
&JoinType::Inner,
)?);
- // Wrap projection
- let (left_on, right_on): (Vec<Expr>, Vec<Expr>) =
- join_keys.into_iter().unzip();
- let (new_left_input, new_left_on, _) =
- wrap_projection_for_join_if_necessary(&left_on,
left_input.clone())?;
- let (new_right_input, new_right_on, _) =
- wrap_projection_for_join_if_necessary(&right_on, right_input)?;
-
- // Build new join on
- let join_on = new_left_on
- .into_iter()
- .zip(new_right_on.into_iter())
- .collect::<Vec<_>>();
-
return Ok(LogicalPlan::Join(Join {
- left: Arc::new(new_left_input),
- right: Arc::new(new_right_input),
+ left: Arc::new(left_input.clone()),
+ right: Arc::new(right_input),
join_type: JoinType::Inner,
join_constraint: JoinConstraint::On,
- on: join_on,
+ on: join_keys,
filter: None,
schema: join_schema,
null_equals_null: false,
@@ -1108,14 +1086,10 @@ mod tests {
.build()?;
let expected = vec![
- "Filter: t2.c < UInt32(20) [a:UInt32, b:UInt32, c:UInt32,
a:UInt32, b:UInt32, c:UInt32]",
- " Projection: t1.a, t1.b, t1.c, t2.a, t2.b, t2.c [a:UInt32,
b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32]",
- " Inner Join: t1.a + UInt32(100) = t2.a * UInt32(2)
[a:UInt32, b:UInt32, c:UInt32, t1.a + UInt32(100):UInt32, a:UInt32, b:UInt32,
c:UInt32, t2.a * UInt32(2):UInt32]",
- " Projection: t1.a, t1.b, t1.c, t1.a + UInt32(100)
[a:UInt32, b:UInt32, c:UInt32, t1.a + UInt32(100):UInt32]",
- " TableScan: t1 [a:UInt32, b:UInt32, c:UInt32]",
- " Projection: t2.a, t2.b, t2.c, t2.a * UInt32(2) [a:UInt32,
b:UInt32, c:UInt32, t2.a * UInt32(2):UInt32]",
- " TableScan: t2 [a:UInt32, b:UInt32, c:UInt32]",
- ];
+ "Filter: t2.c < UInt32(20) [a:UInt32, b:UInt32, c:UInt32,
a:UInt32, b:UInt32, c:UInt32]",
+ " Inner Join: t1.a + UInt32(100) = t2.a * UInt32(2) [a:UInt32,
b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32]",
+ " TableScan: t1 [a:UInt32, b:UInt32, c:UInt32]",
+ " TableScan: t2 [a:UInt32, b:UInt32, c:UInt32]"];
assert_optimized_plan_eq(&plan, expected);
@@ -1167,13 +1141,10 @@ mod tests {
.build()?;
let expected = vec![
- "Filter: t2.c < UInt32(20) AND t2.c = UInt32(10) [a:UInt32,
b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32]",
- " Projection: t1.a, t1.b, t1.c, t2.a, t2.b, t2.c [a:UInt32,
b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32]",
- " Inner Join: t1.a + UInt32(100) = t2.a * UInt32(2)
[a:UInt32, b:UInt32, c:UInt32, t1.a + UInt32(100):UInt32, a:UInt32, b:UInt32,
c:UInt32, t2.a * UInt32(2):UInt32]",
- " Projection: t1.a, t1.b, t1.c, t1.a + UInt32(100)
[a:UInt32, b:UInt32, c:UInt32, t1.a + UInt32(100):UInt32]",
- " TableScan: t1 [a:UInt32, b:UInt32, c:UInt32]",
- " Projection: t2.a, t2.b, t2.c, t2.a * UInt32(2)
[a:UInt32, b:UInt32, c:UInt32, t2.a * UInt32(2):UInt32]",
- " TableScan: t2 [a:UInt32, b:UInt32, c:UInt32]",
+ "Filter: t2.c < UInt32(20) AND t2.c = UInt32(10) [a:UInt32,
b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32]",
+ " Inner Join: t1.a + UInt32(100) = t2.a * UInt32(2) [a:UInt32,
b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32]",
+ " TableScan: t1 [a:UInt32, b:UInt32, c:UInt32]",
+ " TableScan: t2 [a:UInt32, b:UInt32, c:UInt32]",
];
assert_optimized_plan_eq(&plan, expected);
@@ -1198,14 +1169,11 @@ mod tests {
.build()?;
let expected = vec![
- "Filter: t2.c < UInt32(15) OR t2.c = UInt32(688) [a:UInt32,
b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32]",
- " Projection: t1.a, t1.b, t1.c, t2.a, t2.b, t2.c [a:UInt32,
b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32]",
- " Inner Join: t1.a + UInt32(100) = t2.a * UInt32(2)
[a:UInt32, b:UInt32, c:UInt32, t1.a + UInt32(100):UInt32, a:UInt32, b:UInt32,
c:UInt32, t2.a * UInt32(2):UInt32]",
- " Projection: t1.a, t1.b, t1.c, t1.a + UInt32(100)
[a:UInt32, b:UInt32, c:UInt32, t1.a + UInt32(100):UInt32]",
- " TableScan: t1 [a:UInt32, b:UInt32, c:UInt32]",
- " Projection: t2.a, t2.b, t2.c, t2.a * UInt32(2)
[a:UInt32, b:UInt32, c:UInt32, t2.a * UInt32(2):UInt32]",
- " TableScan: t2 [a:UInt32, b:UInt32, c:UInt32]",
- ];
+ "Filter: t2.c < UInt32(15) OR t2.c = UInt32(688) [a:UInt32, b:UInt32,
c:UInt32, a:UInt32, b:UInt32, c:UInt32]",
+ " Inner Join: t1.a + UInt32(100) = t2.a * UInt32(2) [a:UInt32,
b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32]",
+ " TableScan: t1 [a:UInt32, b:UInt32, c:UInt32]",
+ " TableScan: t2 [a:UInt32, b:UInt32, c:UInt32]",
+ ];
assert_optimized_plan_eq(&plan, expected);
@@ -1240,15 +1208,11 @@ mod tests {
let expected = vec![
"Filter: t3.c < UInt32(15) AND t3.b < UInt32(15) [a:UInt32,
b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32,
c:UInt32]",
" Projection: t1.a, t1.b, t1.c, t2.a, t2.b, t2.c, t3.a, t3.b,
t3.c [a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32, a:UInt32,
b:UInt32, c:UInt32]",
- " Inner Join: t3.a + UInt32(100) = t2.a * UInt32(2) [a:UInt32,
b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32, t3.a + UInt32(100):UInt32,
a:UInt32, b:UInt32, c:UInt32, t2.a * UInt32(2):UInt32]",
- " Projection: t1.a, t1.b, t1.c, t3.a, t3.b, t3.c, t3.a +
UInt32(100) [a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32, t3.a +
UInt32(100):UInt32]",
- " Inner Join: t1.a * UInt32(2) = t3.a + UInt32(100)
[a:UInt32, b:UInt32, c:UInt32, t1.a * UInt32(2):UInt32, a:UInt32, b:UInt32,
c:UInt32, t3.a + UInt32(100):UInt32]",
- " Projection: t1.a, t1.b, t1.c, t1.a * UInt32(2)
[a:UInt32, b:UInt32, c:UInt32, t1.a * UInt32(2):UInt32]",
- " TableScan: t1 [a:UInt32, b:UInt32, c:UInt32]",
- " Projection: t3.a, t3.b, t3.c, t3.a + UInt32(100)
[a:UInt32, b:UInt32, c:UInt32, t3.a + UInt32(100):UInt32]",
- " TableScan: t3 [a:UInt32, b:UInt32, c:UInt32]",
- " Projection: t2.a, t2.b, t2.c, t2.a * UInt32(2) [a:UInt32,
b:UInt32, c:UInt32, t2.a * UInt32(2):UInt32]",
- " TableScan: t2 [a:UInt32, b:UInt32, c:UInt32]",
+ " Inner Join: t3.a + UInt32(100) = t2.a * UInt32(2) [a:UInt32,
b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32,
c:UInt32]",
+ " Inner Join: t1.a * UInt32(2) = t3.a + UInt32(100)
[a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32]",
+ " TableScan: t1 [a:UInt32, b:UInt32, c:UInt32]",
+ " TableScan: t3 [a:UInt32, b:UInt32, c:UInt32]",
+ " TableScan: t2 [a:UInt32, b:UInt32, c:UInt32]",
];
assert_optimized_plan_eq(&plan, expected);
diff --git a/datafusion/optimizer/src/filter_null_join_keys.rs
b/datafusion/optimizer/src/filter_null_join_keys.rs
index 4bee88359..eea98ad1f 100644
--- a/datafusion/optimizer/src/filter_null_join_keys.rs
+++ b/datafusion/optimizer/src/filter_null_join_keys.rs
@@ -20,14 +20,12 @@
//! and then insert an `IsNotNull` filter on the nullable side since null
values
//! can never match.
-use std::sync::Arc;
-
-use datafusion_common::{Column, DFField, DFSchemaRef, Result};
+use crate::{utils, OptimizerConfig, OptimizerRule};
+use datafusion_common::Result;
use datafusion_expr::{
- and, logical_plan::Filter, logical_plan::JoinType, Expr, LogicalPlan,
+ and, logical_plan::Filter, logical_plan::JoinType, Expr, ExprSchemable,
LogicalPlan,
};
-
-use crate::{utils, OptimizerConfig, OptimizerRule};
+use std::sync::Arc;
/// The FilterNullJoinKeys rule will identify inner joins with equi-join
conditions
/// where the join key is nullable on one side and non-nullable on the other
side
@@ -66,15 +64,12 @@ impl OptimizerRule for FilterNullJoinKeys {
let mut right_filters = vec![];
for (l, r) in &join.on {
- if let Some((left_field, right_field)) =
- resolve_join_key_pair(left_schema, right_schema, l, r)
- {
- if left_field.is_nullable() {
- left_filters.push(l.clone());
- }
- if right_field.is_nullable() {
- right_filters.push(r.clone());
- }
+ if l.nullable(left_schema)? {
+ left_filters.push(l.clone());
+ }
+
+ if r.nullable(right_schema)? {
+ right_filters.push(r.clone());
}
}
@@ -106,10 +101,10 @@ impl OptimizerRule for FilterNullJoinKeys {
}
}
-fn create_not_null_predicate(columns: Vec<Column>) -> Expr {
- let not_null_exprs: Vec<Expr> = columns
+fn create_not_null_predicate(filters: Vec<Expr>) -> Expr {
+ let not_null_exprs: Vec<Expr> = filters
.into_iter()
- .map(|c| Expr::IsNotNull(Box::new(Expr::Column(c))))
+ .map(|c| Expr::IsNotNull(Box::new(c)))
.collect();
// combine the IsNotNull expressions with AND
not_null_exprs
@@ -118,42 +113,13 @@ fn create_not_null_predicate(columns: Vec<Column>) ->
Expr {
.fold(not_null_exprs[0].clone(), |a, b| and(a, b.clone()))
}
-fn resolve_join_key_pair(
- left_schema: &DFSchemaRef,
- right_schema: &DFSchemaRef,
- c1: &Column,
- c2: &Column,
-) -> Option<(DFField, DFField)> {
- resolve_fields(left_schema, right_schema, c1, c2)
- .or_else(|| resolve_fields(left_schema, right_schema, c2, c1))
-}
-
-fn resolve_fields(
- left_schema: &DFSchemaRef,
- right_schema: &DFSchemaRef,
- c1: &Column,
- c2: &Column,
-) -> Option<(DFField, DFField)> {
- match (
- left_schema.index_of_column(c1),
- right_schema.index_of_column(c2),
- ) {
- (Ok(left_index), Ok(right_index)) => {
- let left_field = left_schema.field(left_index);
- let right_field = right_schema.field(right_index);
- Some((left_field.clone(), right_field.clone()))
- }
- _ => None,
- }
-}
-
#[cfg(test)]
mod tests {
use arrow::datatypes::{DataType, Field, Schema};
use datafusion_common::{Column, Result};
use datafusion_expr::logical_plan::table_scan;
- use datafusion_expr::{logical_plan::JoinType, LogicalPlanBuilder};
+ use datafusion_expr::{col, lit, logical_plan::JoinType,
LogicalPlanBuilder};
use crate::optimizer::OptimizerContext;
@@ -234,6 +200,74 @@ mod tests {
Ok(())
}
+ #[test]
+ fn left_nullable_expr_key() -> Result<()> {
+ let (t1, t2) = test_tables()?;
+ let plan = LogicalPlanBuilder::from(t1)
+ .join_with_expr_keys(
+ t2,
+ JoinType::Inner,
+ (
+ vec![col("t1.optional_id") + lit(1u32)],
+ vec![col("t2.id") + lit(1u32)],
+ ),
+ None,
+ )?
+ .build()?;
+ let expected = "Inner Join: t1.optional_id + UInt32(1) = t2.id +
UInt32(1)\
+ \n Filter: t1.optional_id + UInt32(1) IS NOT NULL\
+ \n TableScan: t1\
+ \n TableScan: t2";
+ assert_optimized_plan_eq(&plan, expected);
+ Ok(())
+ }
+
+ #[test]
+ fn right_nullable_expr_key() -> Result<()> {
+ let (t1, t2) = test_tables()?;
+ let plan = LogicalPlanBuilder::from(t1)
+ .join_with_expr_keys(
+ t2,
+ JoinType::Inner,
+ (
+ vec![col("t1.id") + lit(1u32)],
+ vec![col("t2.optional_id") + lit(1u32)],
+ ),
+ None,
+ )?
+ .build()?;
+ let expected = "Inner Join: t1.id + UInt32(1) = t2.optional_id +
UInt32(1)\
+ \n TableScan: t1\
+ \n Filter: t2.optional_id + UInt32(1) IS NOT NULL\
+ \n TableScan: t2";
+ assert_optimized_plan_eq(&plan, expected);
+ Ok(())
+ }
+
+ #[test]
+ fn both_side_nullable_expr_key() -> Result<()> {
+ let (t1, t2) = test_tables()?;
+ let plan = LogicalPlanBuilder::from(t1)
+ .join_with_expr_keys(
+ t2,
+ JoinType::Inner,
+ (
+ vec![col("t1.optional_id") + lit(1u32)],
+ vec![col("t2.optional_id") + lit(1u32)],
+ ),
+ None,
+ )?
+ .build()?;
+ let expected =
+ "Inner Join: t1.optional_id + UInt32(1) = t2.optional_id +
UInt32(1)\
+ \n Filter: t1.optional_id + UInt32(1) IS NOT NULL\
+ \n TableScan: t1\
+ \n Filter: t2.optional_id + UInt32(1) IS NOT NULL\
+ \n TableScan: t2";
+ assert_optimized_plan_eq(&plan, expected);
+ Ok(())
+ }
+
fn build_plan(
left_table: LogicalPlan,
right_table: LogicalPlan,
diff --git a/datafusion/optimizer/src/push_down_filter.rs
b/datafusion/optimizer/src/push_down_filter.rs
index 0cf6c635b..4ca14c278 100644
--- a/datafusion/optimizer/src/push_down_filter.rs
+++ b/datafusion/optimizer/src/push_down_filter.rs
@@ -457,8 +457,18 @@ fn push_down_join(
Err(e) => return Some(Err(e)),
};
+ // Only allow both side key is column.
+ let join_col_keys = join
+ .on
+ .iter()
+ .flat_map(|(l, r)| match (l.try_into_col(),
r.try_into_col()) {
+ (Ok(l_col), Ok(r_col)) => Some((l_col, r_col)),
+ _ => None,
+ })
+ .collect::<Vec<_>>();
+
for col in columns.iter() {
- for (l, r) in join.on.iter() {
+ for (l, r) in join_col_keys.iter() {
if col == l {
join_cols_to_replace.insert(col, r);
break;
diff --git a/datafusion/optimizer/src/push_down_projection.rs
b/datafusion/optimizer/src/push_down_projection.rs
index 01b089e42..fb8e241fd 100644
--- a/datafusion/optimizer/src/push_down_projection.rs
+++ b/datafusion/optimizer/src/push_down_projection.rs
@@ -156,8 +156,8 @@ fn optimize_plan(
..
}) => {
for (l, r) in on {
- new_required_columns.insert(l.clone());
- new_required_columns.insert(r.clone());
+ new_required_columns.extend(l.to_columns()?);
+ new_required_columns.extend(r.to_columns()?);
}
if let Some(expr) = filter {
diff --git a/datafusion/optimizer/src/subquery_filter_to_join.rs
b/datafusion/optimizer/src/subquery_filter_to_join.rs
index 436d478b9..da7986956 100644
--- a/datafusion/optimizer/src/subquery_filter_to_join.rs
+++ b/datafusion/optimizer/src/subquery_filter_to_join.rs
@@ -132,7 +132,7 @@ impl OptimizerRule for SubqueryFilterToJoin {
Ok(LogicalPlan::Join(Join {
left: Arc::new(input),
right: Arc::new(right_input),
- on: vec![(left_key, right_key)],
+ on: vec![(Expr::Column(left_key),
Expr::Column(right_key))],
filter: None,
join_type,
join_constraint: JoinConstraint::On,
diff --git a/datafusion/proto/proto/datafusion.proto
b/datafusion/proto/proto/datafusion.proto
index 97ba57a7e..0b0124007 100644
--- a/datafusion/proto/proto/datafusion.proto
+++ b/datafusion/proto/proto/datafusion.proto
@@ -257,8 +257,8 @@ message JoinNode {
LogicalPlanNode right = 2;
JoinType join_type = 3;
JoinConstraint join_constraint = 4;
- repeated datafusion.Column left_join_column = 5;
- repeated datafusion.Column right_join_column = 6;
+ repeated LogicalExprNode left_join_key = 5;
+ repeated LogicalExprNode right_join_key = 6;
bool null_equals_null = 7;
LogicalExprNode filter = 8;
}
diff --git a/datafusion/proto/src/generated/pbjson.rs
b/datafusion/proto/src/generated/pbjson.rs
index 13236a935..402048203 100644
--- a/datafusion/proto/src/generated/pbjson.rs
+++ b/datafusion/proto/src/generated/pbjson.rs
@@ -9591,10 +9591,10 @@ impl serde::Serialize for JoinNode {
if self.join_constraint != 0 {
len += 1;
}
- if !self.left_join_column.is_empty() {
+ if !self.left_join_key.is_empty() {
len += 1;
}
- if !self.right_join_column.is_empty() {
+ if !self.right_join_key.is_empty() {
len += 1;
}
if self.null_equals_null {
@@ -9620,11 +9620,11 @@ impl serde::Serialize for JoinNode {
.ok_or_else(|| serde::ser::Error::custom(format!("Invalid
variant {}", self.join_constraint)))?;
struct_ser.serialize_field("joinConstraint", &v)?;
}
- if !self.left_join_column.is_empty() {
- struct_ser.serialize_field("leftJoinColumn",
&self.left_join_column)?;
+ if !self.left_join_key.is_empty() {
+ struct_ser.serialize_field("leftJoinKey", &self.left_join_key)?;
}
- if !self.right_join_column.is_empty() {
- struct_ser.serialize_field("rightJoinColumn",
&self.right_join_column)?;
+ if !self.right_join_key.is_empty() {
+ struct_ser.serialize_field("rightJoinKey", &self.right_join_key)?;
}
if self.null_equals_null {
struct_ser.serialize_field("nullEqualsNull",
&self.null_equals_null)?;
@@ -9648,10 +9648,10 @@ impl<'de> serde::Deserialize<'de> for JoinNode {
"joinType",
"join_constraint",
"joinConstraint",
- "left_join_column",
- "leftJoinColumn",
- "right_join_column",
- "rightJoinColumn",
+ "left_join_key",
+ "leftJoinKey",
+ "right_join_key",
+ "rightJoinKey",
"null_equals_null",
"nullEqualsNull",
"filter",
@@ -9663,8 +9663,8 @@ impl<'de> serde::Deserialize<'de> for JoinNode {
Right,
JoinType,
JoinConstraint,
- LeftJoinColumn,
- RightJoinColumn,
+ LeftJoinKey,
+ RightJoinKey,
NullEqualsNull,
Filter,
}
@@ -9692,8 +9692,8 @@ impl<'de> serde::Deserialize<'de> for JoinNode {
"right" => Ok(GeneratedField::Right),
"joinType" | "join_type" =>
Ok(GeneratedField::JoinType),
"joinConstraint" | "join_constraint" =>
Ok(GeneratedField::JoinConstraint),
- "leftJoinColumn" | "left_join_column" =>
Ok(GeneratedField::LeftJoinColumn),
- "rightJoinColumn" | "right_join_column" =>
Ok(GeneratedField::RightJoinColumn),
+ "leftJoinKey" | "left_join_key" =>
Ok(GeneratedField::LeftJoinKey),
+ "rightJoinKey" | "right_join_key" =>
Ok(GeneratedField::RightJoinKey),
"nullEqualsNull" | "null_equals_null" =>
Ok(GeneratedField::NullEqualsNull),
"filter" => Ok(GeneratedField::Filter),
_ => Err(serde::de::Error::unknown_field(value,
FIELDS)),
@@ -9719,8 +9719,8 @@ impl<'de> serde::Deserialize<'de> for JoinNode {
let mut right__ = None;
let mut join_type__ = None;
let mut join_constraint__ = None;
- let mut left_join_column__ = None;
- let mut right_join_column__ = None;
+ let mut left_join_key__ = None;
+ let mut right_join_key__ = None;
let mut null_equals_null__ = None;
let mut filter__ = None;
while let Some(k) = map.next_key()? {
@@ -9749,17 +9749,17 @@ impl<'de> serde::Deserialize<'de> for JoinNode {
}
join_constraint__ =
Some(map.next_value::<JoinConstraint>()? as i32);
}
- GeneratedField::LeftJoinColumn => {
- if left_join_column__.is_some() {
- return
Err(serde::de::Error::duplicate_field("leftJoinColumn"));
+ GeneratedField::LeftJoinKey => {
+ if left_join_key__.is_some() {
+ return
Err(serde::de::Error::duplicate_field("leftJoinKey"));
}
- left_join_column__ = Some(map.next_value()?);
+ left_join_key__ = Some(map.next_value()?);
}
- GeneratedField::RightJoinColumn => {
- if right_join_column__.is_some() {
- return
Err(serde::de::Error::duplicate_field("rightJoinColumn"));
+ GeneratedField::RightJoinKey => {
+ if right_join_key__.is_some() {
+ return
Err(serde::de::Error::duplicate_field("rightJoinKey"));
}
- right_join_column__ = Some(map.next_value()?);
+ right_join_key__ = Some(map.next_value()?);
}
GeneratedField::NullEqualsNull => {
if null_equals_null__.is_some() {
@@ -9780,8 +9780,8 @@ impl<'de> serde::Deserialize<'de> for JoinNode {
right: right__,
join_type: join_type__.unwrap_or_default(),
join_constraint: join_constraint__.unwrap_or_default(),
- left_join_column: left_join_column__.unwrap_or_default(),
- right_join_column: right_join_column__.unwrap_or_default(),
+ left_join_key: left_join_key__.unwrap_or_default(),
+ right_join_key: right_join_key__.unwrap_or_default(),
null_equals_null: null_equals_null__.unwrap_or_default(),
filter: filter__,
})
diff --git a/datafusion/proto/src/generated/prost.rs
b/datafusion/proto/src/generated/prost.rs
index 1405e1eba..18bf2796f 100644
--- a/datafusion/proto/src/generated/prost.rs
+++ b/datafusion/proto/src/generated/prost.rs
@@ -364,9 +364,9 @@ pub struct JoinNode {
#[prost(enumeration = "JoinConstraint", tag = "4")]
pub join_constraint: i32,
#[prost(message, repeated, tag = "5")]
- pub left_join_column: ::prost::alloc::vec::Vec<Column>,
+ pub left_join_key: ::prost::alloc::vec::Vec<LogicalExprNode>,
#[prost(message, repeated, tag = "6")]
- pub right_join_column: ::prost::alloc::vec::Vec<Column>,
+ pub right_join_key: ::prost::alloc::vec::Vec<LogicalExprNode>,
#[prost(bool, tag = "7")]
pub null_equals_null: bool,
#[prost(message, optional, tag = "8")]
diff --git a/datafusion/proto/src/logical_plan.rs
b/datafusion/proto/src/logical_plan.rs
index 09ae9e41e..f2000d7a7 100644
--- a/datafusion/proto/src/logical_plan.rs
+++ b/datafusion/proto/src/logical_plan.rs
@@ -38,7 +38,7 @@ use datafusion::{
datasource::{provider_as_source, source_as_provider},
prelude::SessionContext,
};
-use datafusion_common::{context, Column, DataFusionError, OwnedTableReference};
+use datafusion_common::{context, DataFusionError, OwnedTableReference};
use datafusion_expr::logical_plan::{builder::project, Prepare};
use datafusion_expr::{
logical_plan::{
@@ -695,10 +695,16 @@ impl AsLogicalPlan for LogicalPlanNode {
LogicalPlanBuilder::from(input).limit(skip, fetch)?.build()
}
LogicalPlanType::Join(join) => {
- let left_keys: Vec<Column> =
- join.left_join_column.iter().map(|i| i.into()).collect();
- let right_keys: Vec<Column> =
- join.right_join_column.iter().map(|i| i.into()).collect();
+ let left_keys: Vec<Expr> = join
+ .left_join_key
+ .iter()
+ .map(|expr| parse_expr(expr, ctx))
+ .collect::<Result<Vec<_>, _>>()?;
+ let right_keys: Vec<Expr> = join
+ .right_join_key
+ .iter()
+ .map(|expr| parse_expr(expr, ctx))
+ .collect::<Result<Vec<_>, _>>()?;
let join_type =
protobuf::JoinType::from_i32(join.join_type).ok_or_else(||
{
proto_error(format!(
@@ -727,17 +733,24 @@ impl AsLogicalPlan for LogicalPlanNode {
extension_codec
)?);
let builder = match join_constraint.into() {
- JoinConstraint::On => builder.join(
+ JoinConstraint::On => builder.join_with_expr_keys(
into_logical_plan!(join.right, ctx, extension_codec)?,
join_type.into(),
(left_keys, right_keys),
filter,
)?,
- JoinConstraint::Using => builder.join_using(
- into_logical_plan!(join.right, ctx, extension_codec)?,
- join_type.into(),
- left_keys,
- )?,
+ JoinConstraint::Using => {
+ // The equijoin keys in using-join must be column.
+ let using_keys = left_keys
+ .into_iter()
+ .map(|key| key.try_into_col())
+ .collect::<Result<Vec<_>, _>>()?;
+ builder.join_using(
+ into_logical_plan!(join.right, ctx,
extension_codec)?,
+ join_type.into(),
+ using_keys,
+ )?
+ }
};
builder.build()
@@ -1100,8 +1113,12 @@ impl AsLogicalPlan for LogicalPlanNode {
right.as_ref(),
extension_codec,
)?;
- let (left_join_column, right_join_column) =
- on.iter().map(|(l, r)| (l.into(), r.into())).unzip();
+ let (left_join_key, right_join_key) = on
+ .iter()
+ .map(|(l, r)| Ok((l.try_into()?, r.try_into()?)))
+ .collect::<Result<Vec<_>, to_proto::Error>>()?
+ .into_iter()
+ .unzip();
let join_type: protobuf::JoinType =
join_type.to_owned().into();
let join_constraint: protobuf::JoinConstraint =
join_constraint.to_owned().into();
@@ -1116,8 +1133,8 @@ impl AsLogicalPlan for LogicalPlanNode {
right: Some(Box::new(right)),
join_type: join_type.into(),
join_constraint: join_constraint.into(),
- left_join_column,
- right_join_column,
+ left_join_key,
+ right_join_key,
null_equals_null: *null_equals_null,
filter,
},
diff --git a/datafusion/sql/src/planner.rs b/datafusion/sql/src/planner.rs
index 761e13fda..7f12e9dec 100644
--- a/datafusion/sql/src/planner.rs
+++ b/datafusion/sql/src/planner.rs
@@ -47,9 +47,7 @@ use datafusion_expr::expr::{
};
use datafusion_expr::expr_rewriter::normalize_col;
use datafusion_expr::expr_rewriter::normalize_col_with_schemas;
-use datafusion_expr::logical_plan::builder::{
- project, wrap_projection_for_join_if_necessary,
-};
+use datafusion_expr::logical_plan::builder::project;
use datafusion_expr::logical_plan::Join as HashJoin;
use datafusion_expr::logical_plan::JoinConstraint as HashJoinConstraint;
use datafusion_expr::logical_plan::{
@@ -852,32 +850,14 @@ impl<'a, S: ContextProvider> SqlToRel<'a, S> {
}
join.build()
} else {
- // Wrap projection for left input if left join keys
contain normal expression.
- let (left_child, left_join_keys, left_projected) =
- wrap_projection_for_join_if_necessary(&left_keys,
left)?;
-
- // Wrap projection for right input if right join keys
contains normal expression.
- let (right_child, right_join_keys, right_projected) =
- wrap_projection_for_join_if_necessary(&right_keys,
right)?;
-
- let join_plan_builder =
LogicalPlanBuilder::from(left_child).join(
- right_child,
- join_type,
- (left_join_keys, right_join_keys),
- join_filter,
- )?;
-
- // Remove temporary projected columns if necessary.
- if left_projected || right_projected {
- let final_join_result = join_schema
- .fields()
- .iter()
- .map(|field|
Expr::Column(field.qualified_column()))
- .collect::<Vec<_>>();
- join_plan_builder.project(final_join_result)?.build()
- } else {
- join_plan_builder.build()
- }
+ LogicalPlanBuilder::from(left)
+ .join_with_expr_keys(
+ right,
+ join_type,
+ (left_keys, right_keys),
+ join_filter,
+ )?
+ .build()
}
}
JoinConstraint::Using(idents) => {
@@ -1045,8 +1025,9 @@ impl<'a, S: ContextProvider> SqlToRel<'a, S> {
{
// For query: select id from t1 join t2 using(id), this is legal.
// We should dedup the fields for cols in using clause.
- for join_cols in on.iter() {
- let left_field =
left.schema().field_from_column(&join_cols.0)?;
+ for join_keys in on.iter() {
+ let join_col = &join_keys.0.try_into_col()?;
+ let left_field = left.schema().field_from_column(join_col)?;
fields.retain(|field| {
field.unqualified_column().name
!= left_field.unqualified_column().name
@@ -6006,12 +5987,10 @@ mod tests {
ON orders.customer_id * 2 = person.id + 10";
let expected = "Projection: person.id, orders.order_id\
- \n Projection: person.id, person.first_name, person.last_name,
person.age, person.state, person.salary, person.birth_date, person.😀,
orders.order_id, orders.customer_id, orders.o_item_id, orders.qty,
orders.price, orders.delivered\
- \n Inner Join: person.id + Int64(10) = orders.customer_id *
Int64(2)\
- \n Projection: person.id, person.first_name, person.last_name,
person.age, person.state, person.salary, person.birth_date, person.😀, person.id
+ Int64(10)\
- \n TableScan: person\
- \n Projection: orders.order_id, orders.customer_id,
orders.o_item_id, orders.qty, orders.price, orders.delivered,
orders.customer_id * Int64(2)\
- \n TableScan: orders";
+ \n Inner Join: person.id + Int64(10) = orders.customer_id * Int64(2)\
+ \n TableScan: person\
+ \n TableScan: orders";
+
quick_test(sql, expected);
}
@@ -6023,12 +6002,9 @@ mod tests {
ON person.id + 10 = orders.customer_id * 2";
let expected = "Projection: person.id, orders.order_id\
- \n Projection: person.id, person.first_name, person.last_name,
person.age, person.state, person.salary, person.birth_date, person.😀,
orders.order_id, orders.customer_id, orders.o_item_id, orders.qty,
orders.price, orders.delivered\
- \n Inner Join: person.id + Int64(10) = orders.customer_id *
Int64(2)\
- \n Projection: person.id, person.first_name, person.last_name,
person.age, person.state, person.salary, person.birth_date, person.😀, person.id
+ Int64(10)\
- \n TableScan: person\
- \n Projection: orders.order_id, orders.customer_id,
orders.o_item_id, orders.qty, orders.price, orders.delivered,
orders.customer_id * Int64(2)\
- \n TableScan: orders";
+ \n Inner Join: person.id + Int64(10) = orders.customer_id * Int64(2)\
+ \n TableScan: person\
+ \n TableScan: orders";
quick_test(sql, expected);
}
@@ -6040,12 +6016,9 @@ mod tests {
ON person.id + person.age + 10 = orders.customer_id * 2 -
orders.price";
let expected = "Projection: person.id, orders.order_id\
- \n Projection: person.id, person.first_name, person.last_name,
person.age, person.state, person.salary, person.birth_date, person.😀,
orders.order_id, orders.customer_id, orders.o_item_id, orders.qty,
orders.price, orders.delivered\
- \n Inner Join: person.id + person.age + Int64(10) =
orders.customer_id * Int64(2) - orders.price\
- \n Projection: person.id, person.first_name, person.last_name,
person.age, person.state, person.salary, person.birth_date, person.😀, person.id
+ person.age + Int64(10)\
- \n TableScan: person\
- \n Projection: orders.order_id, orders.customer_id,
orders.o_item_id, orders.qty, orders.price, orders.delivered,
orders.customer_id * Int64(2) - orders.price\
- \n TableScan: orders";
+ \n Inner Join: person.id + person.age + Int64(10) =
orders.customer_id * Int64(2) - orders.price\
+ \n TableScan: person\
+ \n TableScan: orders";
quick_test(sql, expected);
}
@@ -6057,11 +6030,9 @@ mod tests {
ON person.id + person.age + 10 = orders.customer_id";
let expected = "Projection: person.id, orders.order_id\
- \n Projection: person.id, person.first_name, person.last_name,
person.age, person.state, person.salary, person.birth_date, person.😀,
orders.order_id, orders.customer_id, orders.o_item_id, orders.qty,
orders.price, orders.delivered\
- \n Inner Join: person.id + person.age + Int64(10) =
orders.customer_id\
- \n Projection: person.id, person.first_name, person.last_name,
person.age, person.state, person.salary, person.birth_date, person.😀, person.id
+ person.age + Int64(10)\
- \n TableScan: person\
- \n TableScan: orders";
+ \n Inner Join: person.id + person.age + Int64(10) =
orders.customer_id\
+ \n TableScan: person\
+ \n TableScan: orders";
quick_test(sql, expected);
}
@@ -6073,11 +6044,9 @@ mod tests {
ON person.id = orders.customer_id * 2 - orders.price";
let expected = "Projection: person.id, orders.order_id\
- \n Projection: person.id, person.first_name, person.last_name,
person.age, person.state, person.salary, person.birth_date, person.😀,
orders.order_id, orders.customer_id, orders.o_item_id, orders.qty,
orders.price, orders.delivered\
- \n Inner Join: person.id = orders.customer_id * Int64(2) -
orders.price\
- \n TableScan: person\
- \n Projection: orders.order_id, orders.customer_id,
orders.o_item_id, orders.qty, orders.price, orders.delivered,
orders.customer_id * Int64(2) - orders.price\
- \n TableScan: orders";
+ \n Inner Join: person.id = orders.customer_id * Int64(2) -
orders.price\
+ \n TableScan: person\
+ \n TableScan: orders";
quick_test(sql, expected);
}
@@ -6106,12 +6075,9 @@ mod tests {
ON orders.customer_id * 2 = person.id + 10";
let expected = "Projection: person.id, person.first_name,
person.last_name, person.age, person.state, person.salary, person.birth_date,
person.😀, orders.order_id, orders.customer_id, orders.o_item_id, orders.qty,
orders.price, orders.delivered\
- \n Projection: person.id, person.first_name, person.last_name,
person.age, person.state, person.salary, person.birth_date, person.😀,
orders.order_id, orders.customer_id, orders.o_item_id, orders.qty,
orders.price, orders.delivered\
- \n Inner Join: person.id + Int64(10) = orders.customer_id *
Int64(2)\
- \n Projection: person.id, person.first_name, person.last_name,
person.age, person.state, person.salary, person.birth_date, person.😀, person.id
+ Int64(10)\
- \n TableScan: person\
- \n Projection: orders.order_id, orders.customer_id,
orders.o_item_id, orders.qty, orders.price, orders.delivered,
orders.customer_id * Int64(2)\
- \n TableScan: orders";
+ \n Inner Join: person.id + Int64(10) = orders.customer_id * Int64(2)\
+ \n TableScan: person\
+ \n TableScan: orders";
quick_test(sql, expected);
}
@@ -6123,12 +6089,9 @@ mod tests {
ON orders.customer_id * 2 = person.id + 10";
let expected = "Projection: orders.customer_id * Int64(2), person.id +
Int64(10)\
- \n Projection: person.id, person.first_name, person.last_name,
person.age, person.state, person.salary, person.birth_date, person.😀,
orders.order_id, orders.customer_id, orders.o_item_id, orders.qty,
orders.price, orders.delivered\
- \n Inner Join: person.id + Int64(10) = orders.customer_id *
Int64(2)\
- \n Projection: person.id, person.first_name, person.last_name,
person.age, person.state, person.salary, person.birth_date, person.😀, person.id
+ Int64(10)\
- \n TableScan: person\
- \n Projection: orders.order_id, orders.customer_id,
orders.o_item_id, orders.qty, orders.price, orders.delivered,
orders.customer_id * Int64(2)\
- \n TableScan: orders";
+ \n Inner Join: person.id + Int64(10) = orders.customer_id * Int64(2)\
+ \n TableScan: person\
+ \n TableScan: orders";
quick_test(sql, expected);
}
@@ -6156,12 +6119,9 @@ mod tests {
ON person.id * 2 = orders.customer_id + 10 and person.id * 2 =
orders.order_id";
let expected = "Projection: person.id, person.age\
- \n Projection: person.id, person.first_name, person.last_name,
person.age, person.state, person.salary, person.birth_date, person.😀,
orders.order_id, orders.customer_id, orders.o_item_id, orders.qty,
orders.price, orders.delivered\
- \n Inner Join: person.id * Int64(2) = orders.customer_id +
Int64(10), person.id * Int64(2) = orders.order_id\
- \n Projection: person.id, person.first_name, person.last_name,
person.age, person.state, person.salary, person.birth_date, person.😀, person.id
* Int64(2)\
- \n TableScan: person\
- \n Projection: orders.order_id, orders.customer_id,
orders.o_item_id, orders.qty, orders.price, orders.delivered,
orders.customer_id + Int64(10)\
- \n TableScan: orders";
+ \n Inner Join: person.id * Int64(2) = orders.customer_id + Int64(10),
person.id * Int64(2) = orders.order_id\
+ \n TableScan: person\
+ \n TableScan: orders";
quick_test(sql, expected);
}
@@ -6174,11 +6134,9 @@ mod tests {
ON person.id * 2 = orders.customer_id + 10 and person.id =
orders.customer_id + 10";
let expected = "Projection: person.id, person.age\
- \n Projection: person.id, person.first_name, person.last_name,
person.age, person.state, person.salary, person.birth_date, person.😀,
orders.order_id, orders.customer_id, orders.o_item_id, orders.qty,
orders.price, orders.delivered\n Inner Join: person.id * Int64(2) =
orders.customer_id + Int64(10), person.id = orders.customer_id + Int64(10)\
- \n Projection: person.id, person.first_name, person.last_name,
person.age, person.state, person.salary, person.birth_date, person.😀, person.id
* Int64(2)\
- \n TableScan: person\
- \n Projection: orders.order_id, orders.customer_id,
orders.o_item_id, orders.qty, orders.price, orders.delivered,
orders.customer_id + Int64(10)\
- \n TableScan: orders";
+ \n Inner Join: person.id * Int64(2) = orders.customer_id + Int64(10),
person.id = orders.customer_id + Int64(10)\
+ \n TableScan: person\
+ \n TableScan: orders";
quick_test(sql, expected);
}
@@ -6596,12 +6554,9 @@ mod tests {
ON cast(person.id as Int) = cast(orders.customer_id as Int)";
let expected = "Projection: person.id, person.age\
- \n Projection: person.id, person.first_name, person.last_name,
person.age, person.state, person.salary, person.birth_date, person.😀,
orders.order_id, orders.customer_id, orders.o_item_id, orders.qty,
orders.price, orders.delivered\
- \n Inner Join: CAST(person.id AS Int32) = CAST(orders.customer_id
AS Int32)\
- \n Projection: person.id, person.first_name, person.last_name,
person.age, person.state, person.salary, person.birth_date, person.😀,
CAST(person.id AS Int32) AS CAST(person.id AS Int32)\
- \n TableScan: person\
- \n Projection: orders.order_id, orders.customer_id,
orders.o_item_id, orders.qty, orders.price, orders.delivered,
CAST(orders.customer_id AS Int32) AS CAST(orders.customer_id AS Int32)\
- \n TableScan: orders";
+ \n Inner Join: CAST(person.id AS Int32) = CAST(orders.customer_id AS
Int32)\
+ \n TableScan: person\
+ \n TableScan: orders";
quick_test(sql, expected);
}