This is an automated email from the ASF dual-hosted git repository.
liukun pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/arrow-datafusion.git
The following commit(s) were added to refs/heads/master by this push:
new 734c211c3 Move the extract_join_keys to optimizer (#4711)
734c211c3 is described below
commit 734c211c3832b004cdb3cd57d1815c3fe006388a
Author: ygf11 <[email protected]>
AuthorDate: Tue Dec 27 18:00:13 2022 +0800
Move the extract_join_keys to optimizer (#4711)
* Move extract_join_keys to optimizer
* rename ExtractEquijoinExpr to ExtractEquijoinPredicate
* fix cargo clippy
* utilize the optimizer to traverse the plan tree
* add a new test
---
.../optimizer/src/extract_equijoin_predicate.rs | 420 +++++++++++++++++++++
datafusion/optimizer/src/lib.rs | 1 +
datafusion/optimizer/src/optimizer.rs | 2 +
datafusion/sql/src/planner.rs | 276 ++++----------
4 files changed, 487 insertions(+), 212 deletions(-)
diff --git a/datafusion/optimizer/src/extract_equijoin_predicate.rs
b/datafusion/optimizer/src/extract_equijoin_predicate.rs
new file mode 100644
index 000000000..214fbe728
--- /dev/null
+++ b/datafusion/optimizer/src/extract_equijoin_predicate.rs
@@ -0,0 +1,420 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+//! Optimizer rule that extracts equijoin predicates from a join's filter
+use crate::optimizer::ApplyOrder;
+use crate::{OptimizerConfig, OptimizerRule};
+use datafusion_common::DFSchema;
+use datafusion_common::Result;
+use datafusion_expr::utils::{can_hash, check_all_column_from_schema};
+use datafusion_expr::{BinaryExpr, Expr, ExprSchemable, Join, LogicalPlan,
Operator};
+use std::sync::Arc;
+
+/// Optimization rule that extracts equijoin predicates from the join filter
+#[derive(Default)]
+pub struct ExtractEquijoinPredicate;
+
+impl ExtractEquijoinPredicate {
+ #[allow(missing_docs)]
+ pub fn new() -> Self {
+ Self {}
+ }
+}
+
+impl OptimizerRule for ExtractEquijoinPredicate {
+ fn try_optimize(
+ &self,
+ plan: &LogicalPlan,
+ _config: &dyn OptimizerConfig,
+ ) -> Result<Option<LogicalPlan>> {
+ match plan {
+ LogicalPlan::Join(Join {
+ left,
+ right,
+ on,
+ filter,
+ join_type,
+ join_constraint,
+ schema,
+ null_equals_null,
+ }) => {
+ let left_schema = left.schema();
+ let right_schema = right.schema();
+
+ filter.as_ref().map_or(Result::Ok(None), |expr| {
+ let mut accum: Vec<(Expr, Expr)> = vec![];
+ let mut accum_filter: Vec<Expr> = vec![];
+ // TODO: avoid the clone by using split_conjunction
+ extract_join_keys(
+ expr.clone(),
+ &mut accum,
+ &mut accum_filter,
+ left_schema,
+ right_schema,
+ )?;
+
+ let optimized_plan = (!accum.is_empty()).then(|| {
+ let mut new_on = on.clone();
+ new_on.extend(accum);
+
+ let new_filter =
accum_filter.into_iter().reduce(Expr::and);
+ LogicalPlan::Join(Join {
+ left: left.clone(),
+ right: right.clone(),
+ on: new_on,
+ filter: new_filter,
+ join_type: *join_type,
+ join_constraint: *join_constraint,
+ schema: schema.clone(),
+ null_equals_null: *null_equals_null,
+ })
+ });
+
+ Ok(optimized_plan)
+ })
+ }
+ _ => Ok(None),
+ }
+ }
+
+ fn name(&self) -> &str {
+ "extract_equijoin_predicate"
+ }
+
+ fn apply_order(&self) -> Option<ApplyOrder> {
+ Some(ApplyOrder::BottomUp)
+ }
+}
+
+/// Extracts equijoin keys from an ON condition made of one Eq or several conjunctive Eqs
+/// Filters matching this pattern are added to `accum`
+/// Filters that don't match this pattern are added to `accum_filter`
+/// Examples:
+/// ```text
+/// foo = bar => accum=[(foo, bar)] accum_filter=[]
+/// foo = bar AND bar = baz => accum=[(foo, bar), (bar, baz)] accum_filter=[]
+/// foo = bar AND baz > 1 => accum=[(foo, bar)] accum_filter=[baz > 1]
+///
+/// For equijoin keys, assume we have tables -- a(c0, c1, c2) and b(c0, c1,
c2):
+/// (a.c0 = 10) => accum=[], accum_filter=[a.c0 = 10]
+/// (a.c0 + 1 = b.c0 * 2) => accum=[(a.c0 + 1, b.c0 * 2)], accum_filter=[]
+/// (a.c0 + b.c0 = 10) => accum=[], accum_filter=[a.c0 + b.c0 = 10]
+/// ```
+fn extract_join_keys(
+ expr: Expr,
+ accum: &mut Vec<(Expr, Expr)>,
+ accum_filter: &mut Vec<Expr>,
+ left_schema: &Arc<DFSchema>,
+ right_schema: &Arc<DFSchema>,
+) -> Result<()> {
+ match &expr {
+ Expr::BinaryExpr(BinaryExpr { left, op, right }) => match op {
+ Operator::Eq => {
+ let left = *left.clone();
+ let right = *right.clone();
+ let left_using_columns = left.to_columns()?;
+ let right_using_columns = right.to_columns()?;
+
+ // When one side of the key does not contain columns, we need to move
this expression to the filter.
+ // For example: a = 1, a = now() + 10.
+ if left_using_columns.is_empty() ||
right_using_columns.is_empty() {
+ accum_filter.push(expr);
+ return Ok(());
+ }
+
+ // Check whether the left join key is from the left schema and the right join key
is from the right schema, or the opposite.
+ let l_is_left = check_all_column_from_schema(
+ &left_using_columns,
+ left_schema.clone(),
+ )?;
+ let r_is_right = check_all_column_from_schema(
+ &right_using_columns,
+ right_schema.clone(),
+ )?;
+
+ let r_is_left_and_l_is_right = || {
+ let result = check_all_column_from_schema(
+ &right_using_columns,
+ left_schema.clone(),
+ )? && check_all_column_from_schema(
+ &left_using_columns,
+ right_schema.clone(),
+ )?;
+
+ Result::Ok(result)
+ };
+
+ let join_key_pair = match (l_is_left, r_is_right) {
+ (true, true) => Some((left, right)),
+ (_, _) if r_is_left_and_l_is_right()? => Some((right,
left)),
+ _ => None,
+ };
+
+ if let Some((left_expr, right_expr)) = join_key_pair {
+ let left_expr_type = left_expr.get_type(left_schema)?;
+ let right_expr_type = right_expr.get_type(right_schema)?;
+
+ if can_hash(&left_expr_type) && can_hash(&right_expr_type)
{
+ accum.push((left_expr, right_expr));
+ } else {
+ accum_filter.push(expr);
+ }
+ } else {
+ accum_filter.push(expr);
+ }
+ }
+ Operator::And => {
+ if let Expr::BinaryExpr(BinaryExpr { left, op: _, right }) =
expr {
+ extract_join_keys(
+ *left,
+ accum,
+ accum_filter,
+ left_schema,
+ right_schema,
+ )?;
+ extract_join_keys(
+ *right,
+ accum,
+ accum_filter,
+ left_schema,
+ right_schema,
+ )?;
+ }
+ }
+ _other => {
+ accum_filter.push(expr);
+ }
+ },
+ _other => {
+ accum_filter.push(expr);
+ }
+ }
+
+ Ok(())
+}
+
+#[cfg(test)]
+mod tests {
+ use super::*;
+ use crate::test::*;
+ use datafusion_common::Column;
+ use datafusion_expr::{
+ col, lit, logical_plan::builder::LogicalPlanBuilder, JoinType,
+ };
+
+ fn assert_plan_eq(plan: &LogicalPlan, expected: &str) -> Result<()> {
+ assert_optimized_plan_eq_display_indent(
+ Arc::new(ExtractEquijoinPredicate {}),
+ plan,
+ expected,
+ );
+
+ Ok(())
+ }
+
+ #[test]
+ fn join_with_only_column_equi_predicate() -> Result<()> {
+ let t1 = test_table_scan_with_name("t1")?;
+ let t2 = test_table_scan_with_name("t2")?;
+
+ let plan = LogicalPlanBuilder::from(t1)
+ .join(
+ t2,
+ JoinType::Left,
+ (Vec::<Column>::new(), Vec::<Column>::new()),
+ Some(col("t1.a").eq(col("t2.a"))),
+ )?
+ .build()?;
+ let expected = "Left Join: t1.a = t2.a [a:UInt32, b:UInt32, c:UInt32,
a:UInt32, b:UInt32, c:UInt32]\
+ \n TableScan: t1 [a:UInt32, b:UInt32, c:UInt32]\
+ \n TableScan: t2 [a:UInt32, b:UInt32, c:UInt32]";
+
+ assert_plan_eq(&plan, expected)
+ }
+
+ #[test]
+ fn join_with_only_equi_expr_predicate() -> Result<()> {
+ let t1 = test_table_scan_with_name("t1")?;
+ let t2 = test_table_scan_with_name("t2")?;
+
+ let plan = LogicalPlanBuilder::from(t1)
+ .join(
+ t2,
+ JoinType::Left,
+ (Vec::<Column>::new(), Vec::<Column>::new()),
+ Some((col("t1.a") + lit(10i64)).eq(col("t2.a") * lit(2u32))),
+ )?
+ .build()?;
+ let expected = "Left Join: t1.a + Int64(10) = t2.a * UInt32(2)
[a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32]\
+ \n TableScan: t1 [a:UInt32, b:UInt32, c:UInt32]\
+ \n TableScan: t2 [a:UInt32, b:UInt32, c:UInt32]";
+
+ assert_plan_eq(&plan, expected)
+ }
+
+ #[test]
+ fn join_with_only_none_equi_predicate() -> Result<()> {
+ let t1 = test_table_scan_with_name("t1")?;
+ let t2 = test_table_scan_with_name("t2")?;
+
+ let plan = LogicalPlanBuilder::from(t1)
+ .join(
+ t2,
+ JoinType::Left,
+ (Vec::<Column>::new(), Vec::<Column>::new()),
+ Some(
+ (col("t1.a") + lit(10i64))
+ .gt_eq(col("t2.a") * lit(2u32))
+ .and(col("t1.b").lt(lit(100i32))),
+ ),
+ )?
+ .build()?;
+ let expected = "Left Join: Filter: t1.a + Int64(10) >= t2.a *
UInt32(2) AND t1.b < Int32(100) [a:UInt32, b:UInt32, c:UInt32, a:UInt32,
b:UInt32, c:UInt32]\
+ \n TableScan: t1 [a:UInt32, b:UInt32, c:UInt32]\
+ \n TableScan: t2 [a:UInt32, b:UInt32, c:UInt32]";
+
+ assert_plan_eq(&plan, expected)
+ }
+
+ #[test]
+ fn join_with_expr_both_from_filter_and_keys() -> Result<()> {
+ let t1 = test_table_scan_with_name("t1")?;
+ let t2 = test_table_scan_with_name("t2")?;
+
+ let plan = LogicalPlanBuilder::from(t1)
+ .join_with_expr_keys(
+ t2,
+ JoinType::Left,
+ (
+ vec![col("t1.a") + lit(11u32)],
+ vec![col("t2.a") * lit(2u32)],
+ ),
+ Some(
+ (col("t1.a") + lit(10i64))
+ .eq(col("t2.a") * lit(2u32))
+ .and(col("t1.b").lt(lit(100i32))),
+ ),
+ )?
+ .build()?;
+ let expected = "Left Join: t1.a + UInt32(11) = t2.a * UInt32(2), t1.a
+ Int64(10) = t2.a * UInt32(2) Filter: t1.b < Int32(100) [a:UInt32, b:UInt32,
c:UInt32, a:UInt32, b:UInt32, c:UInt32]\
+ \n TableScan: t1 [a:UInt32, b:UInt32, c:UInt32]\
+ \n TableScan: t2 [a:UInt32, b:UInt32, c:UInt32]";
+
+ assert_plan_eq(&plan, expected)
+ }
+
+ #[test]
+ fn join_with_and_or_filter() -> Result<()> {
+ let t1 = test_table_scan_with_name("t1")?;
+ let t2 = test_table_scan_with_name("t2")?;
+
+ let plan = LogicalPlanBuilder::from(t1)
+ .join(
+ t2,
+ JoinType::Left,
+ (Vec::<Column>::new(), Vec::<Column>::new()),
+ Some(
+ col("t1.c")
+ .eq(col("t2.c"))
+ .or((col("t1.a") + col("t1.b")).gt(col("t2.b") +
col("t2.c")))
+ .and(
+
col("t1.a").eq(col("t2.a")).and(col("t1.b").eq(col("t2.b"))),
+ ),
+ ),
+ )?
+ .build()?;
+ let expected = "Left Join: t1.a = t2.a, t1.b = t2.b Filter: t1.c =
t2.c OR t1.a + t1.b > t2.b + t2.c [a:UInt32, b:UInt32, c:UInt32, a:UInt32,
b:UInt32, c:UInt32]\
+ \n TableScan: t1 [a:UInt32, b:UInt32, c:UInt32]\
+ \n TableScan: t2 [a:UInt32, b:UInt32, c:UInt32]";
+
+ assert_plan_eq(&plan, expected)
+ }
+
+ #[test]
+ fn join_with_multiple_table() -> Result<()> {
+ let t1 = test_table_scan_with_name("t1")?;
+ let t2 = test_table_scan_with_name("t2")?;
+ let t3 = test_table_scan_with_name("t3")?;
+
+ let input = LogicalPlanBuilder::from(t2)
+ .join(
+ t3,
+ JoinType::Left,
+ (Vec::<Column>::new(), Vec::<Column>::new()),
+ Some(
+ col("t2.a")
+ .eq(col("t3.a"))
+ .and((col("t2.a") + col("t3.b")).gt(lit(100u32))),
+ ),
+ )?
+ .build()?;
+ let plan = LogicalPlanBuilder::from(t1)
+ .join(
+ input,
+ JoinType::Left,
+ (Vec::<Column>::new(), Vec::<Column>::new()),
+ Some(
+ col("t1.a")
+ .eq(col("t2.a"))
+ .and((col("t1.c") + col("t2.c") +
col("t3.c")).lt(lit(100u32))),
+ ),
+ )?
+ .build()?;
+ let expected = "Left Join: t1.a = t2.a Filter: t1.c + t2.c + t3.c <
UInt32(100) [a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32,
a:UInt32, b:UInt32, c:UInt32]\
+ \n TableScan: t1 [a:UInt32, b:UInt32, c:UInt32]\
+ \n Left Join: t2.a = t3.a Filter: t2.a + t3.b > UInt32(100)
[a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32]\
+ \n TableScan: t2 [a:UInt32, b:UInt32, c:UInt32]\
+ \n TableScan: t3 [a:UInt32, b:UInt32, c:UInt32]";
+
+ assert_plan_eq(&plan, expected)
+ }
+
+ #[test]
+ fn join_with_multiple_table_and_eq_filter() -> Result<()> {
+ let t1 = test_table_scan_with_name("t1")?;
+ let t2 = test_table_scan_with_name("t2")?;
+ let t3 = test_table_scan_with_name("t3")?;
+
+ let input = LogicalPlanBuilder::from(t2)
+ .join(
+ t3,
+ JoinType::Left,
+ (Vec::<Column>::new(), Vec::<Column>::new()),
+ Some(
+ col("t2.a")
+ .eq(col("t3.a"))
+ .and((col("t2.a") + col("t3.b")).gt(lit(100u32))),
+ ),
+ )?
+ .build()?;
+ let plan = LogicalPlanBuilder::from(t1)
+ .join(
+ input,
+ JoinType::Left,
+ (Vec::<Column>::new(), Vec::<Column>::new()),
+
Some(col("t1.a").eq(col("t2.a")).and(col("t2.c").eq(col("t3.c")))),
+ )?
+ .build()?;
+ let expected = "Left Join: t1.a = t2.a Filter: t2.c = t3.c [a:UInt32,
b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32]\
+ \n TableScan: t1 [a:UInt32, b:UInt32, c:UInt32]\
+ \n Left Join: t2.a = t3.a Filter: t2.a + t3.b > UInt32(100)
[a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32]\
+ \n TableScan: t2 [a:UInt32, b:UInt32, c:UInt32]\
+ \n TableScan: t3 [a:UInt32, b:UInt32, c:UInt32]";
+
+ assert_plan_eq(&plan, expected)
+ }
+}
diff --git a/datafusion/optimizer/src/lib.rs b/datafusion/optimizer/src/lib.rs
index a4804ca5b..b03725fe6 100644
--- a/datafusion/optimizer/src/lib.rs
+++ b/datafusion/optimizer/src/lib.rs
@@ -22,6 +22,7 @@ pub mod eliminate_cross_join;
pub mod eliminate_filter;
pub mod eliminate_limit;
pub mod eliminate_outer_join;
+pub mod extract_equijoin_predicate;
pub mod filter_null_join_keys;
pub mod inline_table_scan;
pub mod optimizer;
diff --git a/datafusion/optimizer/src/optimizer.rs
b/datafusion/optimizer/src/optimizer.rs
index 6fe94e792..36968f2f1 100644
--- a/datafusion/optimizer/src/optimizer.rs
+++ b/datafusion/optimizer/src/optimizer.rs
@@ -24,6 +24,7 @@ use crate::eliminate_cross_join::EliminateCrossJoin;
use crate::eliminate_filter::EliminateFilter;
use crate::eliminate_limit::EliminateLimit;
use crate::eliminate_outer_join::EliminateOuterJoin;
+use crate::extract_equijoin_predicate::ExtractEquijoinPredicate;
use crate::filter_null_join_keys::FilterNullJoinKeys;
use crate::inline_table_scan::InlineTableScan;
use crate::propagate_empty_relation::PropagateEmptyRelation;
@@ -237,6 +238,7 @@ impl Optimizer {
let rules: Vec<Arc<dyn OptimizerRule + Sync + Send>> = vec![
Arc::new(InlineTableScan::new()),
Arc::new(TypeCoercion::new()),
+ Arc::new(ExtractEquijoinPredicate::new()),
Arc::new(SimplifyExpressions::new()),
Arc::new(UnwrapCastInComparison::new()),
Arc::new(DecorrelateWhereExists::new()),
diff --git a/datafusion/sql/src/planner.rs b/datafusion/sql/src/planner.rs
index dc43cbaf1..a4cc0b775 100644
--- a/datafusion/sql/src/planner.rs
+++ b/datafusion/sql/src/planner.rs
@@ -58,9 +58,8 @@ use datafusion_expr::logical_plan::{
};
use datafusion_expr::logical_plan::{Filter, Prepare, Subquery};
use datafusion_expr::utils::{
- can_hash, check_all_column_from_schema, expand_qualified_wildcard,
expand_wildcard,
- expr_as_column_expr, expr_to_columns, find_aggregate_exprs,
find_column_exprs,
- find_window_exprs, COUNT_STAR_EXPANSION,
+ expand_qualified_wildcard, expand_wildcard, expr_as_column_expr,
expr_to_columns,
+ find_aggregate_exprs, find_column_exprs, find_window_exprs,
COUNT_STAR_EXPANSION,
};
use datafusion_expr::Expr::Alias;
use datafusion_expr::{
@@ -806,7 +805,6 @@ impl<'a, S: ContextProvider> SqlToRel<'a, S> {
) -> Result<LogicalPlan> {
match constraint {
JoinConstraint::On(sql_expr) => {
- let mut keys: Vec<(Expr, Expr)> = vec![];
let join_schema = left.schema().join(right.schema())?;
// parse ON expression
@@ -820,45 +818,20 @@ impl<'a, S: ContextProvider> SqlToRel<'a, S> {
// normalize all columns in expression
let using_columns = expr.to_columns()?;
- let normalized_expr = normalize_col_with_schemas(
+ let filter = normalize_col_with_schemas(
expr,
&[left.schema(), right.schema()],
&[using_columns],
)?;
- // expression that didn't match equi-join pattern
- let mut filter = vec![];
-
- // extract join keys
- extract_join_keys(
- normalized_expr,
- &mut keys,
- &mut filter,
- left.schema(),
- right.schema(),
- )?;
-
- let (left_keys, right_keys): (Vec<Expr>, Vec<Expr>) =
- keys.into_iter().unzip();
-
- let join_filter = filter.into_iter().reduce(Expr::and);
-
- if left_keys.is_empty() && join_filter.is_none() {
- let mut join =
LogicalPlanBuilder::from(left).cross_join(right)?;
- if let Some(filter) = join_filter {
- join = join.filter(filter)?;
- }
- join.build()
- } else {
- LogicalPlanBuilder::from(left)
- .join_with_expr_keys(
- right,
- join_type,
- (left_keys, right_keys),
- join_filter,
- )?
- .build()
- }
+ LogicalPlanBuilder::from(left)
+ .join(
+ right,
+ join_type,
+ (Vec::<Column>::new(), Vec::<Column>::new()),
+ Some(filter),
+ )?
+ .build()
}
JoinConstraint::Using(idents) => {
let keys: Vec<Column> = idents
@@ -3095,113 +3068,6 @@ pub fn object_name_to_qualifier(sql_table_name:
&ObjectName) -> String {
.join(" AND ")
}
-/// Extracts equijoin ON condition be a single Eq or multiple conjunctive Eqs
-/// Filters matching this pattern are added to `accum`
-/// Filters that don't match this pattern are added to `accum_filter`
-/// Examples:
-/// ```text
-/// foo = bar => accum=[(foo, bar)] accum_filter=[]
-/// foo = bar AND bar = baz => accum=[(foo, bar), (bar, baz)] accum_filter=[]
-/// foo = bar AND baz > 1 => accum=[(foo, bar)] accum_filter=[baz > 1]
-///
-/// For equijoin join key, assume we have tables -- a(c0, c1 c2) and b(c0, c1,
c2):
-/// (a.c0 = 10) => accum=[], accum_filter=[a.c0 = 10]
-/// (a.c0 + 1 = b.c0 * 2) => accum=[(a.c0 + 1, b.c0 * 2)], accum_filter=[]
-/// (a.c0 + b.c0 = 10) => accum=[], accum_filter=[a.c0 + b.c0 = 10]
-/// ```
-fn extract_join_keys(
- expr: Expr,
- accum: &mut Vec<(Expr, Expr)>,
- accum_filter: &mut Vec<Expr>,
- left_schema: &Arc<DFSchema>,
- right_schema: &Arc<DFSchema>,
-) -> Result<()> {
- match &expr {
- Expr::BinaryExpr(BinaryExpr { left, op, right }) => match op {
- Operator::Eq => {
- let left = *left.clone();
- let right = *right.clone();
- let left_using_columns = left.to_columns()?;
- let right_using_columns = right.to_columns()?;
-
- // When one side key does not contain columns, we need move
this expression to filter.
- // For example: a = 1, a = now() + 10.
- if left_using_columns.is_empty() ||
right_using_columns.is_empty() {
- accum_filter.push(expr);
- return Ok(());
- }
-
- // Checking left join key is from left schema, right join key
is from right schema, or the opposite.
- let l_is_left = check_all_column_from_schema(
- &left_using_columns,
- left_schema.clone(),
- )?;
- let r_is_right = check_all_column_from_schema(
- &right_using_columns,
- right_schema.clone(),
- )?;
-
- let r_is_left_and_l_is_right = || {
- let result = check_all_column_from_schema(
- &right_using_columns,
- left_schema.clone(),
- )? && check_all_column_from_schema(
- &left_using_columns,
- right_schema.clone(),
- )?;
-
- Result::Ok(result)
- };
-
- let join_key_pair = match (l_is_left, r_is_right) {
- (true, true) => Some((left, right)),
- (_, _) if r_is_left_and_l_is_right()? => Some((right,
left)),
- _ => None,
- };
-
- if let Some((left_expr, right_expr)) = join_key_pair {
- let left_expr_type = left_expr.get_type(left_schema)?;
- let right_expr_type = right_expr.get_type(right_schema)?;
-
- if can_hash(&left_expr_type) && can_hash(&right_expr_type)
{
- accum.push((left_expr, right_expr));
- } else {
- accum_filter.push(expr);
- }
- } else {
- accum_filter.push(expr);
- }
- }
- Operator::And => {
- if let Expr::BinaryExpr(BinaryExpr { left, op: _, right }) =
expr {
- extract_join_keys(
- *left,
- accum,
- accum_filter,
- left_schema,
- right_schema,
- )?;
- extract_join_keys(
- *right,
- accum,
- accum_filter,
- left_schema,
- right_schema,
- )?;
- }
- }
- _other => {
- accum_filter.push(expr);
- }
- },
- _other => {
- accum_filter.push(expr);
- }
- }
-
- Ok(())
-}
-
/// Ensure any column reference of the expression is unambiguous.
/// Assume we have two schema:
/// schema1: a, b ,c
@@ -4620,9 +4486,9 @@ mod tests {
JOIN orders \
ON id = customer_id";
let expected = "Projection: person.id, orders.order_id\
- \n Inner Join: person.id = orders.customer_id\
- \n TableScan: person\
- \n TableScan: orders";
+ \n Inner Join: Filter: person.id = orders.customer_id\
+ \n TableScan: person\
+ \n TableScan: orders";
quick_test(sql, expected);
}
@@ -4633,7 +4499,7 @@ mod tests {
JOIN orders \
ON id = customer_id AND order_id > 1 ";
let expected = "Projection: person.id, orders.order_id\
- \n Inner Join: person.id = orders.customer_id Filter:
orders.order_id > Int64(1)\
+ \n Inner Join: Filter: person.id = orders.customer_id AND
orders.order_id > Int64(1)\
\n TableScan: person\
\n TableScan: orders";
@@ -4647,7 +4513,7 @@ mod tests {
LEFT JOIN orders \
ON id = customer_id AND order_id > 1 AND age < 30";
let expected = "Projection: person.id, orders.order_id\
- \n Left Join: person.id = orders.customer_id Filter:
orders.order_id > Int64(1) AND person.age < Int64(30)\
+ \n Left Join: Filter: person.id = orders.customer_id AND
orders.order_id > Int64(1) AND person.age < Int64(30)\
\n TableScan: person\
\n TableScan: orders";
quick_test(sql, expected);
@@ -4659,8 +4525,9 @@ mod tests {
FROM person \
RIGHT JOIN orders \
ON id = customer_id AND id > 1 AND order_id < 100";
+
let expected = "Projection: person.id, orders.order_id\
- \n Right Join: person.id = orders.customer_id Filter: person.id >
Int64(1) AND orders.order_id < Int64(100)\
+ \n Right Join: Filter: person.id = orders.customer_id AND
person.id > Int64(1) AND orders.order_id < Int64(100)\
\n TableScan: person\
\n TableScan: orders";
quick_test(sql, expected);
@@ -4673,9 +4540,9 @@ mod tests {
FULL JOIN orders \
ON id = customer_id AND id > 1 AND order_id < 100";
let expected = "Projection: person.id, orders.order_id\
- \n Full Join: person.id = orders.customer_id Filter: person.id >
Int64(1) AND orders.order_id < Int64(100)\
- \n TableScan: person\
- \n TableScan: orders";
+ \n Full Join: Filter: person.id = orders.customer_id AND
person.id > Int64(1) AND orders.order_id < Int64(100)\
+ \n TableScan: person\
+ \n TableScan: orders";
quick_test(sql, expected);
}
@@ -4686,9 +4553,9 @@ mod tests {
JOIN orders \
ON person.id = orders.customer_id";
let expected = "Projection: person.id, orders.order_id\
- \n Inner Join: person.id = orders.customer_id\
- \n TableScan: person\
- \n TableScan: orders";
+ \n Inner Join: Filter: person.id = orders.customer_id\
+ \n TableScan: person\
+ \n TableScan: orders";
quick_test(sql, expected);
}
@@ -4727,8 +4594,8 @@ mod tests {
JOIN orders ON id = customer_id \
JOIN lineitem ON o_item_id = l_item_id";
let expected = "Projection: person.id, orders.order_id,
lineitem.l_description\
- \n Inner Join: orders.o_item_id = lineitem.l_item_id\
- \n Inner Join: person.id = orders.customer_id\
+ \n Inner Join: Filter: orders.o_item_id = lineitem.l_item_id\
+ \n Inner Join: Filter: person.id = orders.customer_id\
\n TableScan: person\
\n TableScan: orders\
\n TableScan: lineitem";
@@ -5517,11 +5384,11 @@ mod tests {
fn join_with_aliases() {
let sql = "select peeps.id, folks.first_name from person as peeps join
person as folks on peeps.id = folks.id";
let expected = "Projection: peeps.id, folks.first_name\
- \n Inner Join: peeps.id = folks.id\
- \n SubqueryAlias: peeps\
- \n TableScan: person\
- \n SubqueryAlias: folks\
- \n TableScan: person";
+ \n Inner Join: Filter: peeps.id = folks.id\
+ \n SubqueryAlias: peeps\
+ \n TableScan: person\
+ \n SubqueryAlias: folks\
+ \n TableScan: person";
quick_test(sql, expected);
}
@@ -5855,7 +5722,7 @@ mod tests {
FROM person \
JOIN orders ON id = customer_id AND (person.age > 30 OR
person.last_name = 'X')";
let expected = "Projection: person.id, orders.order_id\
- \n Inner Join: person.id = orders.customer_id Filter: person.age
> Int64(30) OR person.last_name = Utf8(\"X\")\
+ \n Inner Join: Filter: person.id = orders.customer_id AND
(person.age > Int64(30) OR person.last_name = Utf8(\"X\"))\
\n TableScan: person\
\n TableScan: orders";
quick_test(sql, expected);
@@ -5981,9 +5848,9 @@ mod tests {
ON orders.customer_id * 2 = person.id + 10";
let expected = "Projection: person.id, orders.order_id\
- \n Inner Join: person.id + Int64(10) = orders.customer_id * Int64(2)\
- \n TableScan: person\
- \n TableScan: orders";
+ \n Inner Join: Filter: orders.customer_id * Int64(2) = person.id
+ Int64(10)\
+ \n TableScan: person\
+ \n TableScan: orders";
quick_test(sql, expected);
}
@@ -5996,9 +5863,9 @@ mod tests {
ON person.id + 10 = orders.customer_id * 2";
let expected = "Projection: person.id, orders.order_id\
- \n Inner Join: person.id + Int64(10) = orders.customer_id * Int64(2)\
- \n TableScan: person\
- \n TableScan: orders";
+ \n Inner Join: Filter: person.id + Int64(10) =
orders.customer_id * Int64(2)\
+ \n TableScan: person\
+ \n TableScan: orders";
quick_test(sql, expected);
}
@@ -6010,37 +5877,37 @@ mod tests {
ON person.id + person.age + 10 = orders.customer_id * 2 -
orders.price";
let expected = "Projection: person.id, orders.order_id\
- \n Inner Join: person.id + person.age + Int64(10) =
orders.customer_id * Int64(2) - orders.price\
- \n TableScan: person\
- \n TableScan: orders";
+ \n Inner Join: Filter: person.id + person.age + Int64(10) =
orders.customer_id * Int64(2) - orders.price\
+ \n TableScan: person\
+ \n TableScan: orders";
quick_test(sql, expected);
}
#[test]
- fn test_left_projection_expr_eq_join() {
+ fn test_left_expr_eq_join() {
let sql = "SELECT id, order_id \
FROM person \
INNER JOIN orders \
ON person.id + person.age + 10 = orders.customer_id";
let expected = "Projection: person.id, orders.order_id\
- \n Inner Join: person.id + person.age + Int64(10) =
orders.customer_id\
- \n TableScan: person\
- \n TableScan: orders";
+ \n Inner Join: Filter: person.id + person.age + Int64(10) =
orders.customer_id\
+ \n TableScan: person\
+ \n TableScan: orders";
quick_test(sql, expected);
}
#[test]
- fn test_right_projection_expr_eq_join() {
+ fn test_right_expr_eq_join() {
let sql = "SELECT id, order_id \
FROM person \
INNER JOIN orders \
ON person.id = orders.customer_id * 2 - orders.price";
let expected = "Projection: person.id, orders.order_id\
- \n Inner Join: person.id = orders.customer_id * Int64(2) -
orders.price\
- \n TableScan: person\
- \n TableScan: orders";
+ \n Inner Join: Filter: person.id = orders.customer_id * Int64(2)
- orders.price\
+ \n TableScan: person\
+ \n TableScan: orders";
quick_test(sql, expected);
}
@@ -6108,9 +5975,9 @@ mod tests {
ON orders.customer_id * 2 = person.id + 10";
let expected = "Projection: person.id, person.first_name,
person.last_name, person.age, person.state, person.salary, person.birth_date,
person.😀, orders.order_id, orders.customer_id, orders.o_item_id, orders.qty,
orders.price, orders.delivered\
- \n Inner Join: person.id + Int64(10) = orders.customer_id * Int64(2)\
- \n TableScan: person\
- \n TableScan: orders";
+ \n Inner Join: Filter: orders.customer_id * Int64(2) = person.id
+ Int64(10)\
+ \n TableScan: person\
+ \n TableScan: orders";
quick_test(sql, expected);
}
@@ -6122,24 +5989,9 @@ mod tests {
ON orders.customer_id * 2 = person.id + 10";
let expected = "Projection: orders.customer_id * Int64(2), person.id +
Int64(10)\
- \n Inner Join: person.id + Int64(10) = orders.customer_id * Int64(2)\
- \n TableScan: person\
- \n TableScan: orders";
- quick_test(sql, expected);
- }
-
- #[test]
- fn test_non_projetion_after_inner_join() {
- // There's no need to add projection for left and right, so does
adding projection after join.
- let sql = "SELECT person.id, person.age
- FROM person
- INNER JOIN orders
- ON orders.customer_id = person.id";
-
- let expected = "Projection: person.id, person.age\
- \n Inner Join: person.id = orders.customer_id\
- \n TableScan: person\
- \n TableScan: orders";
+ \n Inner Join: Filter: orders.customer_id * Int64(2) = person.id
+ Int64(10)\
+ \n TableScan: person\
+ \n TableScan: orders";
quick_test(sql, expected);
}
@@ -6152,9 +6004,9 @@ mod tests {
ON person.id * 2 = orders.customer_id + 10 and person.id * 2 =
orders.order_id";
let expected = "Projection: person.id, person.age\
- \n Inner Join: person.id * Int64(2) = orders.customer_id + Int64(10),
person.id * Int64(2) = orders.order_id\
- \n TableScan: person\
- \n TableScan: orders";
+ \n Inner Join: Filter: person.id * Int64(2) = orders.customer_id
+ Int64(10) AND person.id * Int64(2) = orders.order_id\
+ \n TableScan: person\
+ \n TableScan: orders";
quick_test(sql, expected);
}
@@ -6167,9 +6019,9 @@ mod tests {
ON person.id * 2 = orders.customer_id + 10 and person.id =
orders.customer_id + 10";
let expected = "Projection: person.id, person.age\
- \n Inner Join: person.id * Int64(2) = orders.customer_id + Int64(10),
person.id = orders.customer_id + Int64(10)\
- \n TableScan: person\
- \n TableScan: orders";
+ \n Inner Join: Filter: person.id * Int64(2) = orders.customer_id
+ Int64(10) AND person.id = orders.customer_id + Int64(10)\
+ \n TableScan: person\
+ \n TableScan: orders";
quick_test(sql, expected);
}
@@ -6587,9 +6439,9 @@ mod tests {
ON cast(person.id as Int) = cast(orders.customer_id as Int)";
let expected = "Projection: person.id, person.age\
- \n Inner Join: CAST(person.id AS Int32) = CAST(orders.customer_id AS
Int32)\
- \n TableScan: person\
- \n TableScan: orders";
+ \n Inner Join: Filter: CAST(person.id AS Int32) =
CAST(orders.customer_id AS Int32)\
+ \n TableScan: person\
+ \n TableScan: orders";
quick_test(sql, expected);
}