This is an automated email from the ASF dual-hosted git repository.
alamb pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/arrow-datafusion.git
The following commit(s) were added to refs/heads/master by this push:
     new 18c581c  fix join column handling logic for `On` and `Using` constraints (#605)
18c581c is described below
commit 18c581c4dbfbc3b5d135b3bc0d1cdb5c16af9c78
Author: QP Hou <[email protected]>
AuthorDate: Wed Jul 7 05:06:12 2021 -0700
fix join column handling logic for `On` and `Using` constraints (#605)
* fix join column handling logic for `On` and `Using` constraints
* handling join column expansion during USING JOIN planning
    get rid of shared field and move column expansion logic into plan builder and optimizer.
* add more comments & fix clippy
* add more comment
* reduce duplicate code in join predicate pushdown
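    For illustration, the kind of queries this change affects (both appear in the
    new tests in datafusion/src/execution/context.rs; `test` is the tests' sample
    table):

        -- USING join: the shared key `c2` stays addressable from both sides
        SELECT t1.c1, t1.c2, t2.c2 FROM test t1 JOIN test t2 USING (c2) ORDER BY t2.c2;

        -- equivalent ON join: each side keeps its own `c2` column
        SELECT t1.c1, t1.c2, t2.c2 FROM test t1 JOIN test t2 ON t1.c2 = t2.c2 ORDER BY t1.c2;

    With the `Using` constraint an unqualified reference to the shared join column
    now resolves to one side instead of being rejected as ambiguous; with `On`, the
    two columns remain separate fields in the join output schema and are referenced
    through their qualifiers.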
---
ballista/rust/core/proto/ballista.proto | 10 +-
.../rust/core/src/serde/logical_plan/from_proto.rs | 41 ++--
.../rust/core/src/serde/logical_plan/to_proto.rs | 15 +-
ballista/rust/core/src/serde/mod.rs | 46 ++++-
.../core/src/serde/physical_plan/from_proto.rs | 16 +-
ballista/rust/core/src/serde/physical_plan/mod.rs | 3 +-
.../rust/core/src/serde/physical_plan/to_proto.rs | 13 +-
benchmarks/queries/q7.sql | 2 +-
datafusion/src/execution/context.rs | 90 ++++++++
datafusion/src/execution/dataframe_impl.rs | 23 ++-
datafusion/src/logical_plan/builder.rs | 94 ++-------
datafusion/src/logical_plan/dfschema.rs | 149 +++++++-------
datafusion/src/logical_plan/expr.rs | 94 ++++++---
datafusion/src/logical_plan/mod.rs | 8 +-
datafusion/src/logical_plan/plan.rs | 54 ++++-
datafusion/src/optimizer/filter_push_down.rs | 226 ++++++++++++++++-----
datafusion/src/optimizer/projection_push_down.rs | 88 +++++++-
datafusion/src/optimizer/utils.rs | 9 +-
datafusion/src/physical_plan/hash_join.rs | 198 +++++++++---------
datafusion/src/physical_plan/hash_utils.rs | 57 +-----
datafusion/src/physical_plan/planner.rs | 13 +-
datafusion/src/sql/planner.rs | 76 ++++---
datafusion/src/test/mod.rs | 11 +-
23 files changed, 836 insertions(+), 500 deletions(-)
diff --git a/ballista/rust/core/proto/ballista.proto b/ballista/rust/core/proto/ballista.proto
index e378806..4696d21 100644
--- a/ballista/rust/core/proto/ballista.proto
+++ b/ballista/rust/core/proto/ballista.proto
@@ -378,12 +378,18 @@ enum JoinType {
ANTI = 5;
}
+enum JoinConstraint {
+ ON = 0;
+ USING = 1;
+}
+
message JoinNode {
LogicalPlanNode left = 1;
LogicalPlanNode right = 2;
JoinType join_type = 3;
- repeated Column left_join_column = 4;
- repeated Column right_join_column = 5;
+ JoinConstraint join_constraint = 4;
+ repeated Column left_join_column = 5;
+ repeated Column right_join_column = 6;
}
message LimitNode {
diff --git a/ballista/rust/core/src/serde/logical_plan/from_proto.rs b/ballista/rust/core/src/serde/logical_plan/from_proto.rs
index a1136cf..cad0543 100644
--- a/ballista/rust/core/src/serde/logical_plan/from_proto.rs
+++ b/ballista/rust/core/src/serde/logical_plan/from_proto.rs
@@ -26,8 +26,8 @@ use datafusion::logical_plan::window_frames::{
};
use datafusion::logical_plan::{
    abs, acos, asin, atan, ceil, cos, exp, floor, ln, log10, log2, round, signum, sin,
- sqrt, tan, trunc, Column, DFField, DFSchema, Expr, JoinType, LogicalPlan,
- LogicalPlanBuilder, Operator,
+    sqrt, tan, trunc, Column, DFField, DFSchema, Expr, JoinConstraint, JoinType,
+ LogicalPlan, LogicalPlanBuilder, Operator,
};
use datafusion::physical_plan::aggregates::AggregateFunction;
use datafusion::physical_plan::csv::CsvReadOptions;
@@ -257,23 +257,32 @@ impl TryInto<LogicalPlan> for &protobuf::LogicalPlanNode {
join.join_type
))
})?;
- let join_type = match join_type {
- protobuf::JoinType::Inner => JoinType::Inner,
- protobuf::JoinType::Left => JoinType::Left,
- protobuf::JoinType::Right => JoinType::Right,
- protobuf::JoinType::Full => JoinType::Full,
- protobuf::JoinType::Semi => JoinType::Semi,
- protobuf::JoinType::Anti => JoinType::Anti,
- };
- LogicalPlanBuilder::from(convert_box_required!(join.left)?)
- .join(
+ let join_constraint = protobuf::JoinConstraint::from_i32(
+ join.join_constraint,
+ )
+ .ok_or_else(|| {
+ proto_error(format!(
+                        "Received a JoinNode message with unknown JoinConstraint {}",
+ join.join_constraint
+ ))
+ })?;
+
+                let builder = LogicalPlanBuilder::from(convert_box_required!(join.left)?);
+ let builder = match join_constraint.into() {
+ JoinConstraint::On => builder.join(
&convert_box_required!(join.right)?,
- join_type,
+ join_type.into(),
left_keys,
right_keys,
- )?
- .build()
- .map_err(|e| e.into())
+ )?,
+ JoinConstraint::Using => builder.join_using(
+ &convert_box_required!(join.right)?,
+ join_type.into(),
+ left_keys,
+ )?,
+ };
+
+ builder.build().map_err(|e| e.into())
}
}
}
diff --git a/ballista/rust/core/src/serde/logical_plan/to_proto.rs b/ballista/rust/core/src/serde/logical_plan/to_proto.rs
index 4049622..07d7a59 100644
--- a/ballista/rust/core/src/serde/logical_plan/to_proto.rs
+++ b/ballista/rust/core/src/serde/logical_plan/to_proto.rs
@@ -26,7 +26,7 @@ use datafusion::arrow::datatypes::{DataType, Field, IntervalUnit, Schema, TimeUn
use datafusion::datasource::CsvFile;
use datafusion::logical_plan::{
window_frames::{WindowFrame, WindowFrameBound, WindowFrameUnits},
- Column, Expr, JoinType, LogicalPlan,
+ Column, Expr, JoinConstraint, JoinType, LogicalPlan,
};
use datafusion::physical_plan::aggregates::AggregateFunction;
use datafusion::physical_plan::functions::BuiltinScalarFunction;
@@ -804,26 +804,23 @@ impl TryInto<protobuf::LogicalPlanNode> for &LogicalPlan {
right,
on,
join_type,
+ join_constraint,
..
} => {
                let left: protobuf::LogicalPlanNode = left.as_ref().try_into()?;
                let right: protobuf::LogicalPlanNode = right.as_ref().try_into()?;
- let join_type = match join_type {
- JoinType::Inner => protobuf::JoinType::Inner,
- JoinType::Left => protobuf::JoinType::Left,
- JoinType::Right => protobuf::JoinType::Right,
- JoinType::Full => protobuf::JoinType::Full,
- JoinType::Semi => protobuf::JoinType::Semi,
- JoinType::Anti => protobuf::JoinType::Anti,
- };
let (left_join_column, right_join_column) =
on.iter().map(|(l, r)| (l.into(), r.into())).unzip();
+                let join_type: protobuf::JoinType = join_type.to_owned().into();
+ let join_constraint: protobuf::JoinConstraint =
+ join_constraint.to_owned().into();
Ok(protobuf::LogicalPlanNode {
logical_plan_type: Some(LogicalPlanType::Join(Box::new(
protobuf::JoinNode {
left: Some(Box::new(left)),
right: Some(Box::new(right)),
join_type: join_type.into(),
+ join_constraint: join_constraint.into(),
left_join_column,
right_join_column,
},
diff --git a/ballista/rust/core/src/serde/mod.rs b/ballista/rust/core/src/serde/mod.rs
index af83660..1df0675 100644
--- a/ballista/rust/core/src/serde/mod.rs
+++ b/ballista/rust/core/src/serde/mod.rs
@@ -20,7 +20,7 @@
use std::{convert::TryInto, io::Cursor};
-use datafusion::logical_plan::Operator;
+use datafusion::logical_plan::{JoinConstraint, JoinType, Operator};
use datafusion::physical_plan::aggregates::AggregateFunction;
use datafusion::physical_plan::window_functions::BuiltInWindowFunction;
@@ -291,3 +291,47 @@ impl Into<datafusion::arrow::datatypes::DataType> for protobuf::PrimitiveScalarT
}
}
}
+
+impl From<protobuf::JoinType> for JoinType {
+ fn from(t: protobuf::JoinType) -> Self {
+ match t {
+ protobuf::JoinType::Inner => JoinType::Inner,
+ protobuf::JoinType::Left => JoinType::Left,
+ protobuf::JoinType::Right => JoinType::Right,
+ protobuf::JoinType::Full => JoinType::Full,
+ protobuf::JoinType::Semi => JoinType::Semi,
+ protobuf::JoinType::Anti => JoinType::Anti,
+ }
+ }
+}
+
+impl From<JoinType> for protobuf::JoinType {
+ fn from(t: JoinType) -> Self {
+ match t {
+ JoinType::Inner => protobuf::JoinType::Inner,
+ JoinType::Left => protobuf::JoinType::Left,
+ JoinType::Right => protobuf::JoinType::Right,
+ JoinType::Full => protobuf::JoinType::Full,
+ JoinType::Semi => protobuf::JoinType::Semi,
+ JoinType::Anti => protobuf::JoinType::Anti,
+ }
+ }
+}
+
+impl From<protobuf::JoinConstraint> for JoinConstraint {
+ fn from(t: protobuf::JoinConstraint) -> Self {
+ match t {
+ protobuf::JoinConstraint::On => JoinConstraint::On,
+ protobuf::JoinConstraint::Using => JoinConstraint::Using,
+ }
+ }
+}
+
+impl From<JoinConstraint> for protobuf::JoinConstraint {
+ fn from(t: JoinConstraint) -> Self {
+ match t {
+ JoinConstraint::On => protobuf::JoinConstraint::On,
+ JoinConstraint::Using => protobuf::JoinConstraint::Using,
+ }
+ }
+}
diff --git a/ballista/rust/core/src/serde/physical_plan/from_proto.rs b/ballista/rust/core/src/serde/physical_plan/from_proto.rs
index 717ee20..12c1743 100644
--- a/ballista/rust/core/src/serde/physical_plan/from_proto.rs
+++ b/ballista/rust/core/src/serde/physical_plan/from_proto.rs
@@ -35,7 +35,9 @@ use datafusion::catalog::catalog::{
use datafusion::execution::context::{
ExecutionConfig, ExecutionContextState, ExecutionProps,
};
-use datafusion::logical_plan::{window_frames::WindowFrame, DFSchema, Expr};
+use datafusion::logical_plan::{
+ window_frames::WindowFrame, DFSchema, Expr, JoinConstraint, JoinType,
+};
use datafusion::physical_plan::aggregates::{create_aggregate_expr, AggregateFunction};
use datafusion::physical_plan::coalesce_partitions::CoalescePartitionsExec;
use datafusion::physical_plan::hash_aggregate::{AggregateMode, HashAggregateExec};
@@ -57,7 +59,6 @@ use datafusion::physical_plan::{
filter::FilterExec,
functions::{self, BuiltinScalarFunction, ScalarFunctionExpr},
hash_join::HashJoinExec,
- hash_utils::JoinType,
limit::{GlobalLimitExec, LocalLimitExec},
parquet::ParquetExec,
projection::ProjectionExec,
@@ -348,14 +349,7 @@ impl TryInto<Arc<dyn ExecutionPlan>> for &protobuf::PhysicalPlanNode {
hashjoin.join_type
))
})?;
- let join_type = match join_type {
- protobuf::JoinType::Inner => JoinType::Inner,
- protobuf::JoinType::Left => JoinType::Left,
- protobuf::JoinType::Right => JoinType::Right,
- protobuf::JoinType::Full => JoinType::Full,
- protobuf::JoinType::Semi => JoinType::Semi,
- protobuf::JoinType::Anti => JoinType::Anti,
- };
+
let partition_mode =
protobuf::PartitionMode::from_i32(hashjoin.partition_mode)
.ok_or_else(|| {
@@ -372,7 +366,7 @@ impl TryInto<Arc<dyn ExecutionPlan>> for &protobuf::PhysicalPlanNode {
left,
right,
on,
- &join_type,
+ &join_type.into(),
partition_mode,
)?))
}
diff --git a/ballista/rust/core/src/serde/physical_plan/mod.rs b/ballista/rust/core/src/serde/physical_plan/mod.rs
index a393d7f..3bf7e9c 100644
--- a/ballista/rust/core/src/serde/physical_plan/mod.rs
+++ b/ballista/rust/core/src/serde/physical_plan/mod.rs
@@ -27,7 +27,7 @@ mod roundtrip_tests {
compute::kernels::sort::SortOptions,
datatypes::{DataType, Field, Schema},
},
- logical_plan::Operator,
+ logical_plan::{JoinType, Operator},
physical_plan::{
empty::EmptyExec,
expressions::{binary, col, lit, InListExpr, NotExpr},
@@ -35,7 +35,6 @@ mod roundtrip_tests {
filter::FilterExec,
hash_aggregate::{AggregateMode, HashAggregateExec},
hash_join::{HashJoinExec, PartitionMode},
- hash_utils::JoinType,
limit::{GlobalLimitExec, LocalLimitExec},
sort::SortExec,
            AggregateExpr, ColumnarValue, Distribution, ExecutionPlan, Partitioning,
diff --git a/ballista/rust/core/src/serde/physical_plan/to_proto.rs b/ballista/rust/core/src/serde/physical_plan/to_proto.rs
index 0fc2785..875dbf2 100644
--- a/ballista/rust/core/src/serde/physical_plan/to_proto.rs
+++ b/ballista/rust/core/src/serde/physical_plan/to_proto.rs
@@ -26,6 +26,7 @@ use std::{
sync::Arc,
};
+use datafusion::logical_plan::JoinType;
use datafusion::physical_plan::coalesce_batches::CoalesceBatchesExec;
use datafusion::physical_plan::csv::CsvExec;
use datafusion::physical_plan::expressions::{
@@ -35,7 +36,6 @@ use datafusion::physical_plan::expressions::{CastExpr, TryCastExpr};
use datafusion::physical_plan::filter::FilterExec;
use datafusion::physical_plan::hash_aggregate::AggregateMode;
use datafusion::physical_plan::hash_join::{HashJoinExec, PartitionMode};
-use datafusion::physical_plan::hash_utils::JoinType;
use datafusion::physical_plan::limit::{GlobalLimitExec, LocalLimitExec};
use datafusion::physical_plan::parquet::ParquetExec;
use datafusion::physical_plan::projection::ProjectionExec;
@@ -135,18 +135,13 @@ impl TryInto<protobuf::PhysicalPlanNode> for Arc<dyn ExecutionPlan> {
}),
})
.collect();
- let join_type = match exec.join_type() {
- JoinType::Inner => protobuf::JoinType::Inner,
- JoinType::Left => protobuf::JoinType::Left,
- JoinType::Right => protobuf::JoinType::Right,
- JoinType::Full => protobuf::JoinType::Full,
- JoinType::Semi => protobuf::JoinType::Semi,
- JoinType::Anti => protobuf::JoinType::Anti,
- };
+            let join_type: protobuf::JoinType = exec.join_type().to_owned().into();
+
let partition_mode = match exec.partition_mode() {
                PartitionMode::CollectLeft => protobuf::PartitionMode::CollectLeft,
                PartitionMode::Partitioned => protobuf::PartitionMode::Partitioned,
};
+
Ok(protobuf::PhysicalPlanNode {
physical_plan_type: Some(PhysicalPlanType::HashJoin(Box::new(
protobuf::HashJoinExecNode {
diff --git a/benchmarks/queries/q7.sql b/benchmarks/queries/q7.sql
index d53877c..512e5be 100644
--- a/benchmarks/queries/q7.sql
+++ b/benchmarks/queries/q7.sql
@@ -36,4 +36,4 @@ group by
order by
supp_nation,
cust_nation,
- l_year;
\ No newline at end of file
+ l_year;
diff --git a/datafusion/src/execution/context.rs b/datafusion/src/execution/context.rs
index 6a26e04..d2dcec5 100644
--- a/datafusion/src/execution/context.rs
+++ b/datafusion/src/execution/context.rs
@@ -1279,6 +1279,96 @@ mod tests {
}
#[tokio::test]
+ async fn left_join_using() -> Result<()> {
+ let results = execute(
+            "SELECT t1.c1, t2.c2 FROM test t1 JOIN test t2 USING (c2) ORDER BY t2.c2",
+ 1,
+ )
+ .await?;
+ assert_eq!(results.len(), 1);
+
+ let expected = vec![
+ "+----+----+",
+ "| c1 | c2 |",
+ "+----+----+",
+ "| 0 | 1 |",
+ "| 0 | 2 |",
+ "| 0 | 3 |",
+ "| 0 | 4 |",
+ "| 0 | 5 |",
+ "| 0 | 6 |",
+ "| 0 | 7 |",
+ "| 0 | 8 |",
+ "| 0 | 9 |",
+ "| 0 | 10 |",
+ "+----+----+",
+ ];
+
+ assert_batches_eq!(expected, &results);
+ Ok(())
+ }
+
+ #[tokio::test]
+ async fn left_join_using_join_key_projection() -> Result<()> {
+ let results = execute(
+            "SELECT t1.c1, t1.c2, t2.c2 FROM test t1 JOIN test t2 USING (c2) ORDER BY t2.c2",
+ 1,
+ )
+ .await?;
+ assert_eq!(results.len(), 1);
+
+ let expected = vec![
+ "+----+----+----+",
+ "| c1 | c2 | c2 |",
+ "+----+----+----+",
+ "| 0 | 1 | 1 |",
+ "| 0 | 2 | 2 |",
+ "| 0 | 3 | 3 |",
+ "| 0 | 4 | 4 |",
+ "| 0 | 5 | 5 |",
+ "| 0 | 6 | 6 |",
+ "| 0 | 7 | 7 |",
+ "| 0 | 8 | 8 |",
+ "| 0 | 9 | 9 |",
+ "| 0 | 10 | 10 |",
+ "+----+----+----+",
+ ];
+
+ assert_batches_eq!(expected, &results);
+ Ok(())
+ }
+
+ #[tokio::test]
+ async fn left_join() -> Result<()> {
+ let results = execute(
+            "SELECT t1.c1, t1.c2, t2.c2 FROM test t1 JOIN test t2 ON t1.c2 = t2.c2 ORDER BY t1.c2",
+ 1,
+ )
+ .await?;
+ assert_eq!(results.len(), 1);
+
+ let expected = vec![
+ "+----+----+----+",
+ "| c1 | c2 | c2 |",
+ "+----+----+----+",
+ "| 0 | 1 | 1 |",
+ "| 0 | 2 | 2 |",
+ "| 0 | 3 | 3 |",
+ "| 0 | 4 | 4 |",
+ "| 0 | 5 | 5 |",
+ "| 0 | 6 | 6 |",
+ "| 0 | 7 | 7 |",
+ "| 0 | 8 | 8 |",
+ "| 0 | 9 | 9 |",
+ "| 0 | 10 | 10 |",
+ "+----+----+----+",
+ ];
+
+ assert_batches_eq!(expected, &results);
+ Ok(())
+ }
+
+ #[tokio::test]
async fn window() -> Result<()> {
let results = execute(
"SELECT \
diff --git a/datafusion/src/execution/dataframe_impl.rs b/datafusion/src/execution/dataframe_impl.rs
index 7cf7797..4edd01c 100644
--- a/datafusion/src/execution/dataframe_impl.rs
+++ b/datafusion/src/execution/dataframe_impl.rs
@@ -264,7 +264,7 @@ mod tests {
#[tokio::test]
async fn join() -> Result<()> {
let left = test_table()?.select_columns(&["c1", "c2"])?;
- let right = test_table()?.select_columns(&["c1", "c3"])?;
+ let right = test_table_with_name("c2")?.select_columns(&["c1", "c3"])?;
let left_rows = left.collect().await?;
let right_rows = right.collect().await?;
let join = left.join(right, JoinType::Inner, &["c1"], &["c1"])?;
@@ -315,7 +315,7 @@ mod tests {
#[test]
fn registry() -> Result<()> {
let mut ctx = ExecutionContext::new();
- register_aggregate_csv(&mut ctx)?;
+ register_aggregate_csv(&mut ctx, "aggregate_test_100")?;
// declare the udf
let my_fn: ScalarFunctionImplementation =
@@ -366,21 +366,28 @@ mod tests {
/// Create a logical plan from a SQL query
fn create_plan(sql: &str) -> Result<LogicalPlan> {
let mut ctx = ExecutionContext::new();
- register_aggregate_csv(&mut ctx)?;
+ register_aggregate_csv(&mut ctx, "aggregate_test_100")?;
ctx.create_logical_plan(sql)
}
- fn test_table() -> Result<Arc<dyn DataFrame + 'static>> {
+    fn test_table_with_name(name: &str) -> Result<Arc<dyn DataFrame + 'static>> {
let mut ctx = ExecutionContext::new();
- register_aggregate_csv(&mut ctx)?;
- ctx.table("aggregate_test_100")
+ register_aggregate_csv(&mut ctx, name)?;
+ ctx.table(name)
+ }
+
+ fn test_table() -> Result<Arc<dyn DataFrame + 'static>> {
+ test_table_with_name("aggregate_test_100")
}
- fn register_aggregate_csv(ctx: &mut ExecutionContext) -> Result<()> {
+ fn register_aggregate_csv(
+ ctx: &mut ExecutionContext,
+ table_name: &str,
+ ) -> Result<()> {
let schema = test::aggr_test_schema();
let testdata = crate::test_util::arrow_test_data();
ctx.register_csv(
- "aggregate_test_100",
+ table_name,
&format!("{}/csv/aggregate_test_100.csv", testdata),
CsvReadOptions::new().schema(schema.as_ref()),
)?;
diff --git a/datafusion/src/logical_plan/builder.rs b/datafusion/src/logical_plan/builder.rs
index 1a53e21..41f29c4 100644
--- a/datafusion/src/logical_plan/builder.rs
+++ b/datafusion/src/logical_plan/builder.rs
@@ -40,7 +40,6 @@ use crate::logical_plan::{
columnize_expr, normalize_col, normalize_cols, Column, DFField, DFSchema,
DFSchemaRef, Partitioning,
};
-use std::collections::HashSet;
/// Default table name for unnamed table
pub const UNNAMED_TABLE: &str = "?table?";
@@ -217,7 +216,6 @@ impl LogicalPlanBuilder {
/// * An invalid expression is used (e.g. a `sort` expression)
    pub fn project(&self, expr: impl IntoIterator<Item = Expr>) -> Result<Self> {
let input_schema = self.plan.schema();
- let all_schemas = self.plan.all_schemas();
let mut projected_expr = vec![];
for e in expr {
match e {
@@ -227,10 +225,8 @@ impl LogicalPlanBuilder {
.push(Expr::Column(input_schema.field(i).qualified_column()))
});
}
- _ => projected_expr.push(columnize_expr(
- normalize_col(e, &all_schemas)?,
- input_schema,
- )),
+ _ => projected_expr
+                    .push(columnize_expr(normalize_col(e, &self.plan)?, input_schema)),
}
}
@@ -247,7 +243,7 @@ impl LogicalPlanBuilder {
/// Apply a filter
pub fn filter(&self, expr: Expr) -> Result<Self> {
- let expr = normalize_col(expr, &self.plan.all_schemas())?;
+ let expr = normalize_col(expr, &self.plan)?;
Ok(Self::from(LogicalPlan::Filter {
predicate: expr,
input: Arc::new(self.plan.clone()),
@@ -264,9 +260,8 @@ impl LogicalPlanBuilder {
/// Apply a sort
pub fn sort(&self, exprs: impl IntoIterator<Item = Expr>) -> Result<Self> {
- let schemas = self.plan.all_schemas();
Ok(Self::from(LogicalPlan::Sort {
- expr: normalize_cols(exprs, &schemas)?,
+ expr: normalize_cols(exprs, &self.plan)?,
input: Arc::new(self.plan.clone()),
}))
}
@@ -292,20 +287,15 @@ impl LogicalPlanBuilder {
let left_keys: Vec<Column> = left_keys
.into_iter()
- .map(|c| c.into().normalize(&self.plan.all_schemas()))
+ .map(|c| c.into().normalize(&self.plan))
.collect::<Result<_>>()?;
let right_keys: Vec<Column> = right_keys
.into_iter()
- .map(|c| c.into().normalize(&right.all_schemas()))
+ .map(|c| c.into().normalize(right))
.collect::<Result<_>>()?;
let on: Vec<(_, _)> =
left_keys.into_iter().zip(right_keys.into_iter()).collect();
- let join_schema = build_join_schema(
- self.plan.schema(),
- right.schema(),
- &on,
- &join_type,
- &JoinConstraint::On,
- )?;
+ let join_schema =
+ build_join_schema(self.plan.schema(), right.schema(), &join_type)?;
Ok(Self::from(LogicalPlan::Join {
left: Arc::new(self.plan.clone()),
@@ -327,21 +317,16 @@ impl LogicalPlanBuilder {
let left_keys: Vec<Column> = using_keys
.clone()
.into_iter()
- .map(|c| c.into().normalize(&self.plan.all_schemas()))
+ .map(|c| c.into().normalize(&self.plan))
.collect::<Result<_>>()?;
let right_keys: Vec<Column> = using_keys
.into_iter()
- .map(|c| c.into().normalize(&right.all_schemas()))
+ .map(|c| c.into().normalize(right))
.collect::<Result<_>>()?;
let on: Vec<(_, _)> =
left_keys.into_iter().zip(right_keys.into_iter()).collect();
- let join_schema = build_join_schema(
- self.plan.schema(),
- right.schema(),
- &on,
- &join_type,
- &JoinConstraint::Using,
- )?;
+ let join_schema =
+ build_join_schema(self.plan.schema(), right.schema(), &join_type)?;
Ok(Self::from(LogicalPlan::Join {
left: Arc::new(self.plan.clone()),
@@ -394,9 +379,8 @@ impl LogicalPlanBuilder {
group_expr: impl IntoIterator<Item = Expr>,
aggr_expr: impl IntoIterator<Item = Expr>,
) -> Result<Self> {
- let schemas = self.plan.all_schemas();
- let group_expr = normalize_cols(group_expr, &schemas)?;
- let aggr_expr = normalize_cols(aggr_expr, &schemas)?;
+ let group_expr = normalize_cols(group_expr, &self.plan)?;
+ let aggr_expr = normalize_cols(aggr_expr, &self.plan)?;
let all_expr = group_expr.iter().chain(aggr_expr.iter());
        validate_unique_names("Aggregations", all_expr.clone(), self.plan.schema())?;
@@ -440,33 +424,12 @@ impl LogicalPlanBuilder {
pub fn build_join_schema(
left: &DFSchema,
right: &DFSchema,
- on: &[(Column, Column)],
join_type: &JoinType,
- join_constraint: &JoinConstraint,
) -> Result<DFSchema> {
let fields: Vec<DFField> = match join_type {
- JoinType::Inner | JoinType::Left | JoinType::Full => {
- let duplicate_keys = match join_constraint {
- JoinConstraint::On => on
- .iter()
- .filter(|(l, r)| l == r)
- .map(|on| on.1.clone())
- .collect::<HashSet<_>>(),
-            // using join requires unique join columns in the output schema, so we mark all
- // right join keys as duplicate
- JoinConstraint::Using => {
- on.iter().map(|on| on.1.clone()).collect::<HashSet<_>>()
- }
- };
-
+        JoinType::Inner | JoinType::Left | JoinType::Full | JoinType::Right => {
+ let right_fields = right.fields().iter();
let left_fields = left.fields().iter();
-
-            // remove right-side join keys if they have the same names as the left-side
- let right_fields = right
- .fields()
- .iter()
- .filter(|f| !duplicate_keys.contains(&f.qualified_column()));
-
// left then right
left_fields.chain(right_fields).cloned().collect()
}
@@ -474,31 +437,6 @@ pub fn build_join_schema(
// Only use the left side for the schema
left.fields().clone()
}
- JoinType::Right => {
- let duplicate_keys = match join_constraint {
- JoinConstraint::On => on
- .iter()
- .filter(|(l, r)| l == r)
- .map(|on| on.1.clone())
- .collect::<HashSet<_>>(),
-            // using join requires unique join columns in the output schema, so we mark all
- // left join keys as duplicate
- JoinConstraint::Using => {
- on.iter().map(|on| on.0.clone()).collect::<HashSet<_>>()
- }
- };
-
-            // remove left-side join keys if they have the same names as the right-side
- let left_fields = left
- .fields()
- .iter()
- .filter(|f| !duplicate_keys.contains(&f.qualified_column()));
-
- let right_fields = right.fields().iter();
-
- // left then right
- left_fields.chain(right_fields).cloned().collect()
- }
};
DFSchema::new(fields)
diff --git a/datafusion/src/logical_plan/dfschema.rs b/datafusion/src/logical_plan/dfschema.rs
index b4d864f..b4bde87 100644
--- a/datafusion/src/logical_plan/dfschema.rs
+++ b/datafusion/src/logical_plan/dfschema.rs
@@ -48,6 +48,7 @@ impl DFSchema {
pub fn new(fields: Vec<DFField>) -> Result<Self> {
let mut qualified_names = HashSet::new();
let mut unqualified_names = HashSet::new();
+
for field in &fields {
if let Some(qualifier) = field.qualifier() {
if !qualified_names.insert((qualifier, field.name())) {
@@ -94,10 +95,7 @@ impl DFSchema {
schema
.fields()
.iter()
- .map(|f| DFField {
- field: f.clone(),
- qualifier: Some(qualifier.to_owned()),
- })
+ .map(|f| DFField::from_qualified(qualifier, f.clone()))
.collect(),
)
}
@@ -149,47 +147,80 @@ impl DFSchema {
)))
}
- /// Find the index of the column with the given qualifer and name
- pub fn index_of_column(&self, col: &Column) -> Result<usize> {
- for i in 0..self.fields.len() {
- let field = &self.fields[i];
-            if field.qualifier() == col.relation.as_ref() && field.name() == &col.name {
- return Ok(i);
- }
+ fn index_of_column_by_name(
+ &self,
+ qualifier: Option<&str>,
+ name: &str,
+ ) -> Result<usize> {
+ let matches: Vec<usize> = self
+ .fields
+ .iter()
+ .enumerate()
+ .filter(|(_, field)| match (qualifier, &field.qualifier) {
+ // field to lookup is qualified.
+                // current field is qualified and not shared between relations, compare both
+                // qualifier and name.
+                (Some(q), Some(field_q)) => q == field_q && field.name() == name,
+                // field to lookup is qualified but current field is unqualified.
+ (Some(_), None) => false,
+ // field to lookup is unqualified, no need to compare qualifier
+ _ => field.name() == name,
+ })
+ .map(|(idx, _)| idx)
+ .collect();
+
+ match matches.len() {
+ 0 => Err(DataFusionError::Plan(format!(
+ "No field named '{}.{}'. Valid fields are {}.",
+ qualifier.unwrap_or(""),
+ name,
+ self.get_field_names()
+ ))),
+ 1 => Ok(matches[0]),
+ _ => Err(DataFusionError::Internal(format!(
+ "Ambiguous reference to qualified field named '{}.{}'",
+ qualifier.unwrap_or(""),
+ name
+ ))),
}
- Err(DataFusionError::Plan(format!(
- "No field matches column '{}'. Available fields: {}",
- col, self
- )))
+ }
+
+ /// Find the index of the column with the given qualifier and name
+ pub fn index_of_column(&self, col: &Column) -> Result<usize> {
+ self.index_of_column_by_name(col.relation.as_deref(), &col.name)
}
/// Find the field with the given name
pub fn field_with_name(
&self,
- relation_name: Option<&str>,
+ qualifier: Option<&str>,
name: &str,
- ) -> Result<DFField> {
- if let Some(relation_name) = relation_name {
- self.field_with_qualified_name(relation_name, name)
+ ) -> Result<&DFField> {
+ if let Some(qualifier) = qualifier {
+ self.field_with_qualified_name(qualifier, name)
} else {
self.field_with_unqualified_name(name)
}
}
- /// Find the field with the given name
- pub fn field_with_unqualified_name(&self, name: &str) -> Result<DFField> {
- let matches: Vec<&DFField> = self
- .fields
+    /// Find all fields that match the given name
+ pub fn fields_with_unqualified_name(&self, name: &str) -> Vec<&DFField> {
+ self.fields
.iter()
.filter(|field| field.name() == name)
- .collect();
+ .collect()
+ }
+
+ /// Find the field with the given name
+ pub fn field_with_unqualified_name(&self, name: &str) -> Result<&DFField> {
+ let matches = self.fields_with_unqualified_name(name);
match matches.len() {
0 => Err(DataFusionError::Plan(format!(
"No field with unqualified name '{}'. Valid fields are {}.",
name,
self.get_field_names()
))),
- 1 => Ok(matches[0].to_owned()),
+ 1 => Ok(matches[0]),
_ => Err(DataFusionError::Plan(format!(
"Ambiguous reference to field named '{}'",
name
@@ -200,33 +231,15 @@ impl DFSchema {
/// Find the field with the given qualified name
pub fn field_with_qualified_name(
&self,
- relation_name: &str,
+ qualifier: &str,
name: &str,
- ) -> Result<DFField> {
- let matches: Vec<&DFField> = self
- .fields
- .iter()
- .filter(|field| {
-                field.qualifier == Some(relation_name.to_string()) && field.name() == name
- })
- .collect();
- match matches.len() {
- 0 => Err(DataFusionError::Plan(format!(
- "No field named '{}.{}'. Valid fields are {}.",
- relation_name,
- name,
- self.get_field_names()
- ))),
- 1 => Ok(matches[0].to_owned()),
- _ => Err(DataFusionError::Internal(format!(
- "Ambiguous reference to qualified field named '{}.{}'",
- relation_name, name
- ))),
- }
+ ) -> Result<&DFField> {
+ let idx = self.index_of_column_by_name(Some(qualifier), name)?;
+ Ok(self.field(idx))
}
/// Find the field with the given qualified column
-    pub fn field_from_qualified_column(&self, column: &Column) -> Result<DFField> {
+ pub fn field_from_column(&self, column: &Column) -> Result<&DFField> {
match &column.relation {
Some(r) => self.field_with_qualified_name(r, &column.name),
None => self.field_with_unqualified_name(&column.name),
@@ -247,31 +260,20 @@ impl DFSchema {
fields: self
.fields
.into_iter()
- .map(|f| {
- if f.qualifier().is_some() {
- DFField::new(
- None,
- f.name(),
- f.data_type().to_owned(),
- f.is_nullable(),
- )
- } else {
- f
- }
- })
+ .map(|f| f.strip_qualifier())
.collect(),
}
}
/// Replace all field qualifier with new value in schema
- pub fn replace_qualifier(self, qualifer: &str) -> Self {
+ pub fn replace_qualifier(self, qualifier: &str) -> Self {
DFSchema {
fields: self
.fields
.into_iter()
.map(|f| {
DFField::new(
- Some(qualifer),
+ Some(qualifier),
f.name(),
f.data_type().to_owned(),
f.is_nullable(),
@@ -328,10 +330,7 @@ impl TryFrom<Schema> for DFSchema {
schema
.fields()
.iter()
- .map(|f| DFField {
- field: f.clone(),
- qualifier: None,
- })
+ .map(|f| DFField::from(f.clone()))
.collect(),
)
}
@@ -454,8 +453,8 @@ impl DFField {
/// Returns a string to the `DFField`'s qualified name
pub fn qualified_name(&self) -> String {
- if let Some(relation_name) = &self.qualifier {
- format!("{}.{}", relation_name, self.field.name())
+ if let Some(qualifier) = &self.qualifier {
+ format!("{}.{}", qualifier, self.field.name())
} else {
self.field.name().to_owned()
}
@@ -469,6 +468,14 @@ impl DFField {
}
}
+ /// Builds an unqualified column based on self
+ pub fn unqualified_column(&self) -> Column {
+ Column {
+ relation: None,
+ name: self.field.name().to_string(),
+ }
+ }
+
/// Get the optional qualifier
pub fn qualifier(&self) -> Option<&String> {
self.qualifier.as_ref()
@@ -478,6 +485,12 @@ impl DFField {
pub fn field(&self) -> &Field {
&self.field
}
+
+ /// Return field with qualifier stripped
+ pub fn strip_qualifier(mut self) -> Self {
+ self.qualifier = None;
+ self
+ }
}
#[cfg(test)]
diff --git a/datafusion/src/logical_plan/expr.rs b/datafusion/src/logical_plan/expr.rs
index 1fab9bb..9454d75 100644
--- a/datafusion/src/logical_plan/expr.rs
+++ b/datafusion/src/logical_plan/expr.rs
@@ -20,7 +20,7 @@
pub use super::Operator;
use crate::error::{DataFusionError, Result};
-use crate::logical_plan::{window_frames, DFField, DFSchema, DFSchemaRef};
+use crate::logical_plan::{window_frames, DFField, DFSchema, LogicalPlan};
use crate::physical_plan::{
aggregates, expressions::binary_operator_data_type, functions,
udf::ScalarUDF,
window_functions,
@@ -29,7 +29,7 @@ use crate::{physical_plan::udaf::AggregateUDF, scalar::ScalarValue};
use aggregates::{AccumulatorFunctionImplementation, StateTypeFunction};
use arrow::{compute::can_cast_types, datatypes::DataType};
use functions::{ReturnTypeFunction, ScalarFunctionImplementation, Signature};
-use std::collections::HashSet;
+use std::collections::{HashMap, HashSet};
use std::fmt;
use std::sync::Arc;
@@ -89,14 +89,46 @@ impl Column {
///
/// For example, `foo` will be normalized to `t.foo` if there is a
/// column named `foo` in a relation named `t` found in `schemas`
- pub fn normalize(self, schemas: &[&DFSchemaRef]) -> Result<Self> {
+ pub fn normalize(self, plan: &LogicalPlan) -> Result<Self> {
if self.relation.is_some() {
return Ok(self);
}
- for schema in schemas {
- if let Ok(field) = schema.field_with_unqualified_name(&self.name) {
- return Ok(field.qualified_column());
+ let schemas = plan.all_schemas();
+ let using_columns = plan.using_columns()?;
+
+ for schema in &schemas {
+ let fields = schema.fields_with_unqualified_name(&self.name);
+ match fields.len() {
+ 0 => continue,
+ 1 => {
+ return Ok(fields[0].qualified_column());
+ }
+ _ => {
+                    // More than 1 fields in this schema have their names set to self.name.
+ //
+                    // This should only happen when a JOIN query with USING constraint references
+ // join columns using unqualified column name. For example:
+ //
+ // ```sql
+ // SELECT id FROM t1 JOIN t2 USING(id)
+ // ```
+ //
+                    // In this case, both `t1.id` and `t2.id` will match unqualified column `id`.
+                    // We will use the relation from the first matched field to normalize self.
+
+                    // Compare matched fields with one USING JOIN clause at a time
+ for using_col in &using_columns {
+ let all_matched = fields
+ .iter()
+                            .all(|f| using_col.contains(&f.qualified_column()));
+                        // All matched fields belong to the same using column set, in other words
+                        // the same join clause. We simply pick the qualifier from the first match.
+ if all_matched {
+ return Ok(fields[0].qualified_column());
+ }
+ }
+ }
}
}
@@ -321,9 +353,7 @@ impl Expr {
pub fn get_type(&self, schema: &DFSchema) -> Result<DataType> {
match self {
Expr::Alias(expr, _) => expr.get_type(schema),
- Expr::Column(c) => {
- Ok(schema.field_from_qualified_column(c)?.data_type().clone())
- }
+            Expr::Column(c) => Ok(schema.field_from_column(c)?.data_type().clone()),
Expr::ScalarVariable(_) => Ok(DataType::Utf8),
Expr::Literal(l) => Ok(l.get_datatype()),
            Expr::Case { when_then_expr, .. } => when_then_expr[0].1.get_type(schema),
@@ -395,9 +425,7 @@ impl Expr {
pub fn nullable(&self, input_schema: &DFSchema) -> Result<bool> {
match self {
Expr::Alias(expr, _) => expr.nullable(input_schema),
- Expr::Column(c) => {
- Ok(input_schema.field_from_qualified_column(c)?.is_nullable())
- }
+            Expr::Column(c) => Ok(input_schema.field_from_column(c)?.is_nullable()),
Expr::Literal(value) => Ok(value.is_null()),
Expr::ScalarVariable(_) => Ok(true),
Expr::Case {
@@ -1118,36 +1146,56 @@ pub fn columnize_expr(e: Expr, input_schema: &DFSchema) -> Expr {
}
}
+/// Recursively replace all Column expressions in a given expression tree with Column expressions
+/// provided by the hash map argument.
+pub fn replace_col(e: Expr, replace_map: &HashMap<&Column, &Column>) -> Result<Expr> {
+ struct ColumnReplacer<'a> {
+ replace_map: &'a HashMap<&'a Column, &'a Column>,
+ }
+
+ impl<'a> ExprRewriter for ColumnReplacer<'a> {
+ fn mutate(&mut self, expr: Expr) -> Result<Expr> {
+ if let Expr::Column(c) = &expr {
+ match self.replace_map.get(c) {
+ Some(new_c) => Ok(Expr::Column((*new_c).to_owned())),
+ None => Ok(expr),
+ }
+ } else {
+ Ok(expr)
+ }
+ }
+ }
+
+ e.rewrite(&mut ColumnReplacer { replace_map })
+}
+
/// Recursively call [`Column::normalize`] on all Column expressions
/// in the `expr` expression tree.
-pub fn normalize_col(e: Expr, schemas: &[&DFSchemaRef]) -> Result<Expr> {
- struct ColumnNormalizer<'a, 'b> {
- schemas: &'a [&'b DFSchemaRef],
+pub fn normalize_col(e: Expr, plan: &LogicalPlan) -> Result<Expr> {
+ struct ColumnNormalizer<'a> {
+ plan: &'a LogicalPlan,
}
- impl<'a, 'b> ExprRewriter for ColumnNormalizer<'a, 'b> {
+ impl<'a> ExprRewriter for ColumnNormalizer<'a> {
fn mutate(&mut self, expr: Expr) -> Result<Expr> {
if let Expr::Column(c) = expr {
- Ok(Expr::Column(c.normalize(self.schemas)?))
+ Ok(Expr::Column(c.normalize(self.plan)?))
} else {
Ok(expr)
}
}
}
- e.rewrite(&mut ColumnNormalizer { schemas })
+ e.rewrite(&mut ColumnNormalizer { plan })
}
/// Recursively normalize all Column expressions in a list of expression trees
#[inline]
pub fn normalize_cols(
exprs: impl IntoIterator<Item = Expr>,
- schemas: &[&DFSchemaRef],
+ plan: &LogicalPlan,
) -> Result<Vec<Expr>> {
- exprs
- .into_iter()
- .map(|e| normalize_col(e, schemas))
- .collect()
+ exprs.into_iter().map(|e| normalize_col(e, plan)).collect()
}
/// Create an expression to represent the min() aggregate function
diff --git a/datafusion/src/logical_plan/mod.rs b/datafusion/src/logical_plan/mod.rs
index 69d03d2..86a2f56 100644
--- a/datafusion/src/logical_plan/mod.rs
+++ b/datafusion/src/logical_plan/mod.rs
@@ -41,10 +41,10 @@ pub use expr::{
     cos, count, count_distinct, create_udaf, create_udf, exp, exprlist_to_fields, floor,
     in_list, initcap, left, length, lit, ln, log10, log2, lower, lpad, ltrim, max, md5,
     min, normalize_col, normalize_cols, now, octet_length, or, random, regexp_match,
-    regexp_replace, repeat, replace, reverse, right, round, rpad, rtrim, sha224, sha256,
-    sha384, sha512, signum, sin, split_part, sqrt, starts_with, strpos, substr, sum, tan,
-    to_hex, translate, trim, trunc, upper, when, Column, Expr, ExprRewriter,
-    ExpressionVisitor, Literal, Recursion,
+    regexp_replace, repeat, replace, replace_col, reverse, right, round, rpad, rtrim,
+    sha224, sha256, sha384, sha512, signum, sin, split_part, sqrt, starts_with, strpos,
+    substr, sum, tan, to_hex, translate, trim, trunc, upper, when, Column, Expr,
+ ExprRewriter, ExpressionVisitor, Literal, Recursion,
};
pub use extension::UserDefinedLogicalNode;
pub use operators::Operator;
diff --git a/datafusion/src/logical_plan/plan.rs b/datafusion/src/logical_plan/plan.rs
index 99f0fa1..b954b6a 100644
--- a/datafusion/src/logical_plan/plan.rs
+++ b/datafusion/src/logical_plan/plan.rs
@@ -21,9 +21,11 @@ use super::display::{GraphvizVisitor, IndentVisitor};
use super::expr::{Column, Expr};
use super::extension::UserDefinedLogicalNode;
use crate::datasource::TableProvider;
+use crate::error::DataFusionError;
use crate::logical_plan::dfschema::DFSchemaRef;
use crate::sql::parser::FileType;
use arrow::datatypes::{DataType, Field, Schema, SchemaRef};
+use std::collections::HashSet;
use std::{
fmt::{self, Display},
sync::Arc,
@@ -354,6 +356,43 @@ impl LogicalPlan {
| LogicalPlan::CreateExternalTable { .. } => vec![],
}
}
+
+ /// returns all `Using` join columns in a logical plan
+    pub fn using_columns(&self) -> Result<Vec<HashSet<Column>>, DataFusionError> {
+ struct UsingJoinColumnVisitor {
+ using_columns: Vec<HashSet<Column>>,
+ }
+
+ impl PlanVisitor for UsingJoinColumnVisitor {
+ type Error = DataFusionError;
+
+            fn pre_visit(&mut self, plan: &LogicalPlan) -> Result<bool, Self::Error> {
+ if let LogicalPlan::Join {
+ join_constraint: JoinConstraint::Using,
+ on,
+ ..
+ } = plan
+ {
+ self.using_columns.push(
+ on.iter()
+ .map(|entry| {
+ std::iter::once(entry.0.clone())
+ .chain(std::iter::once(entry.1.clone()))
+ })
+ .flatten()
+ .collect::<HashSet<_>>(),
+ );
+ }
+ Ok(true)
+ }
+ }
+
+ let mut visitor = UsingJoinColumnVisitor {
+ using_columns: vec![],
+ };
+ self.accept(&mut visitor)?;
+ Ok(visitor.using_columns)
+ }
}
/// Logical partitioning schemes supported by the repartition operator.
@@ -709,10 +748,21 @@ impl LogicalPlan {
}
Ok(())
}
- LogicalPlan::Join { on: ref keys, .. } => {
+ LogicalPlan::Join {
+ on: ref keys,
+ join_constraint,
+ ..
+ } => {
let join_expr: Vec<String> =
                    keys.iter().map(|(l, r)| format!("{} = {}", l, r)).collect();
- write!(f, "Join: {}", join_expr.join(", "))
+ match join_constraint {
+ JoinConstraint::On => {
+ write!(f, "Join: {}", join_expr.join(", "))
+ }
+ JoinConstraint::Using => {
+                        write!(f, "Join: Using {}", join_expr.join(", "))
+ }
+ }
}
LogicalPlan::CrossJoin { .. } => {
write!(f, "CrossJoin:")
diff --git a/datafusion/src/optimizer/filter_push_down.rs b/datafusion/src/optimizer/filter_push_down.rs
index c1d81fe..76d8c05 100644
--- a/datafusion/src/optimizer/filter_push_down.rs
+++ b/datafusion/src/optimizer/filter_push_down.rs
@@ -16,7 +16,7 @@
use crate::datasource::datasource::TableProviderFilterPushDown;
use crate::execution::context::ExecutionProps;
-use crate::logical_plan::{and, Column, LogicalPlan};
+use crate::logical_plan::{and, replace_col, Column, LogicalPlan};
use crate::logical_plan::{DFSchema, Expr};
use crate::optimizer::optimizer::OptimizerRule;
use crate::optimizer::utils;
@@ -96,12 +96,21 @@ fn get_join_predicates<'a>(
let left_columns = &left
.fields()
.iter()
- .map(|f| f.qualified_column())
+ .map(|f| {
+ std::iter::once(f.qualified_column())
+ // we need to push down filter using unqualified column as well
+ .chain(std::iter::once(f.unqualified_column()))
+ })
+ .flatten()
.collect::<HashSet<_>>();
let right_columns = &right
.fields()
.iter()
- .map(|f| f.qualified_column())
+ .map(|f| {
+ std::iter::once(f.qualified_column())
+ .chain(std::iter::once(f.unqualified_column()))
+ })
+ .flatten()
.collect::<HashSet<_>>();
let filters = state
@@ -232,6 +241,38 @@ fn split_members<'a>(predicate: &'a Expr, predicates: &mut Vec<&'a Expr>) {
}
}
+fn optimize_join(
+ mut state: State,
+ plan: &LogicalPlan,
+ left: &LogicalPlan,
+ right: &LogicalPlan,
+) -> Result<LogicalPlan> {
+ let (pushable_to_left, pushable_to_right, keep) =
+ get_join_predicates(&state, left.schema(), right.schema());
+
+ let mut left_state = state.clone();
+ left_state.filters = keep_filters(&left_state.filters, &pushable_to_left);
+ let left = optimize(left, left_state)?;
+
+ let mut right_state = state.clone();
+    right_state.filters = keep_filters(&right_state.filters, &pushable_to_right);
+ let right = optimize(right, right_state)?;
+
+ // create a new Join with the new `left` and `right`
+ let expr = plan.expressions();
+ let plan = utils::from_plan(plan, &expr, &[left, right])?;
+
+ if keep.0.is_empty() {
+ Ok(plan)
+ } else {
+ // wrap the join on the filter whose predicates must be kept
+ let plan = add_filter(plan, &keep.0);
+ state.filters = remove_filters(&state.filters, &keep.1);
+
+ Ok(plan)
+ }
+}
+
fn optimize(plan: &LogicalPlan, mut state: State) -> Result<LogicalPlan> {
match plan {
LogicalPlan::Explain { .. } => {
@@ -336,32 +377,68 @@ fn optimize(plan: &LogicalPlan, mut state: State) -> Result<LogicalPlan> {
.collect::<HashSet<_>>();
issue_filters(state, used_columns, plan)
}
- LogicalPlan::Join { left, right, .. }
- | LogicalPlan::CrossJoin { left, right, .. } => {
- let (pushable_to_left, pushable_to_right, keep) =
- get_join_predicates(&state, left.schema(), right.schema());
-
- let mut left_state = state.clone();
-            left_state.filters = keep_filters(&left_state.filters, &pushable_to_left);
- let left = optimize(left, left_state)?;
-
- let mut right_state = state.clone();
-            right_state.filters = keep_filters(&right_state.filters, &pushable_to_right);
- let right = optimize(right, right_state)?;
-
- // create a new Join with the new `left` and `right`
- let expr = plan.expressions();
- let plan = utils::from_plan(plan, &expr, &[left, right])?;
+ LogicalPlan::CrossJoin { left, right, .. } => {
+ optimize_join(state, plan, left, right)
+ }
+ LogicalPlan::Join {
+ left, right, on, ..
+ } => {
+            // duplicate filters for joined columns so filters can be pushed down to both sides.
+ // Take the following query as an example:
+ //
+ // ```sql
+ // SELECT * FROM t1 JOIN t2 on t1.id = t2.uid WHERE t1.id > 1
+ // ```
+ //
+            // `t1.id > 1` predicate needs to be pushed down to t1 table scan, while
+ // `t2.uid > 1` predicate needs to be pushed down to t2 table scan.
+ //
+            // Join clauses with `Using` constraints also take advantage of this logic to make sure
+            // predicates that reference the shared join columns are pushed to both sides.
+ let join_side_filters = state
+ .filters
+ .iter()
+ .filter_map(|(predicate, columns)| {
+ let mut join_cols_to_replace = HashMap::new();
+ for col in columns.iter() {
+ for (l, r) in on {
+ if col == l {
+ join_cols_to_replace.insert(col, r);
+ break;
+ } else if col == r {
+ join_cols_to_replace.insert(col, l);
+ break;
+ }
+ }
+ }
- if keep.0.is_empty() {
- Ok(plan)
- } else {
- // wrap the join on the filter whose predicates must be kept
- let plan = add_filter(plan, &keep.0);
- state.filters = remove_filters(&state.filters, &keep.1);
+ if join_cols_to_replace.is_empty() {
+ return None;
+ }
- Ok(plan)
- }
+ let join_side_predicate =
+                        match replace_col(predicate.clone(), &join_cols_to_replace) {
+ Ok(p) => p,
+ Err(e) => {
+ return Some(Err(e));
+ }
+ };
+
+ let join_side_columns = columns
+ .clone()
+ .into_iter()
+                        // replace keys in join_cols_to_replace with values in resulting column
+ // set
+ .filter(|c| !join_cols_to_replace.contains_key(c))
+                        .chain(join_cols_to_replace.iter().map(|(_, v)| (*v).clone()))
+ .collect();
+
+ Some(Ok((join_side_predicate, join_side_columns)))
+ })
+ .collect::<Result<Vec<_>>>()?;
+ state.filters.extend(join_side_filters);
+
+ optimize_join(state, plan, left, right)
}
LogicalPlan::TableScan {
source,
@@ -878,12 +955,13 @@ mod tests {
Ok(())
}
-    /// post-join predicates on a column common to both sides is pushed to both sides
+    /// post-on-join predicates on a column common to both sides is pushed to both sides
#[test]
- fn filter_join_on_common_independent() -> Result<()> {
+ fn filter_on_join_on_common_independent() -> Result<()> {
let table_scan = test_table_scan()?;
- let left = LogicalPlanBuilder::from(table_scan.clone()).build()?;
- let right = LogicalPlanBuilder::from(table_scan)
+ let left = LogicalPlanBuilder::from(table_scan).build()?;
+ let right_table_scan = test_table_scan_with_name("test2")?;
+ let right = LogicalPlanBuilder::from(right_table_scan)
.project(vec![col("a")])?
.build()?;
let plan = LogicalPlanBuilder::from(left)
@@ -901,20 +979,61 @@ mod tests {
format!("{:?}", plan),
"\
Filter: #test.a LtEq Int64(1)\
- \n Join: #test.a = #test.a\
+ \n Join: #test.a = #test2.a\
\n TableScan: test projection=None\
- \n Projection: #test.a\
- \n TableScan: test projection=None"
+ \n Projection: #test2.a\
+ \n TableScan: test2 projection=None"
);
// filter sent to side before the join
let expected = "\
- Join: #test.a = #test.a\
+ Join: #test.a = #test2.a\
\n Filter: #test.a LtEq Int64(1)\
\n TableScan: test projection=None\
- \n Projection: #test.a\
- \n Filter: #test.a LtEq Int64(1)\
- \n TableScan: test projection=None";
+ \n Projection: #test2.a\
+ \n Filter: #test2.a LtEq Int64(1)\
+ \n TableScan: test2 projection=None";
+ assert_optimized_plan_eq(&plan, expected);
+ Ok(())
+ }
+
+    /// post-using-join predicates on a column common to both sides is pushed to both sides
+ #[test]
+ fn filter_using_join_on_common_independent() -> Result<()> {
+ let table_scan = test_table_scan()?;
+ let left = LogicalPlanBuilder::from(table_scan).build()?;
+ let right_table_scan = test_table_scan_with_name("test2")?;
+ let right = LogicalPlanBuilder::from(right_table_scan)
+ .project(vec![col("a")])?
+ .build()?;
+ let plan = LogicalPlanBuilder::from(left)
+ .join_using(
+ &right,
+ JoinType::Inner,
+ vec![Column::from_name("a".to_string())],
+ )?
+ .filter(col("a").lt_eq(lit(1i64)))?
+ .build()?;
+
+ // not part of the test, just good to know:
+ assert_eq!(
+ format!("{:?}", plan),
+ "\
+ Filter: #test.a LtEq Int64(1)\
+ \n Join: Using #test.a = #test2.a\
+ \n TableScan: test projection=None\
+ \n Projection: #test2.a\
+ \n TableScan: test2 projection=None"
+ );
+
+ // filter sent to side before the join
+ let expected = "\
+ Join: Using #test.a = #test2.a\
+ \n Filter: #test.a LtEq Int64(1)\
+ \n TableScan: test projection=None\
+ \n Projection: #test2.a\
+ \n Filter: #test2.a LtEq Int64(1)\
+ \n TableScan: test2 projection=None";
assert_optimized_plan_eq(&plan, expected);
Ok(())
}
@@ -923,10 +1042,11 @@ mod tests {
#[test]
fn filter_join_on_common_dependent() -> Result<()> {
let table_scan = test_table_scan()?;
- let left = LogicalPlanBuilder::from(table_scan.clone())
+ let left = LogicalPlanBuilder::from(table_scan)
.project(vec![col("a"), col("c")])?
.build()?;
- let right = LogicalPlanBuilder::from(table_scan)
+ let right_table_scan = test_table_scan_with_name("test2")?;
+ let right = LogicalPlanBuilder::from(right_table_scan)
.project(vec![col("a"), col("b")])?
.build()?;
let plan = LogicalPlanBuilder::from(left)
@@ -944,12 +1064,12 @@ mod tests {
assert_eq!(
format!("{:?}", plan),
"\
- Filter: #test.c LtEq #test.b\
- \n Join: #test.a = #test.a\
+ Filter: #test.c LtEq #test2.b\
+ \n Join: #test.a = #test2.a\
\n Projection: #test.a, #test.c\
\n TableScan: test projection=None\
- \n Projection: #test.a, #test.b\
- \n TableScan: test projection=None"
+ \n Projection: #test2.a, #test2.b\
+ \n TableScan: test2 projection=None"
);
// expected is equal: no push-down
@@ -962,12 +1082,14 @@ mod tests {
#[test]
fn filter_join_on_one_side() -> Result<()> {
let table_scan = test_table_scan()?;
- let left = LogicalPlanBuilder::from(table_scan.clone())
+ let left = LogicalPlanBuilder::from(table_scan)
.project(vec![col("a"), col("b")])?
.build()?;
- let right = LogicalPlanBuilder::from(table_scan)
+ let table_scan_right = test_table_scan_with_name("test2")?;
+ let right = LogicalPlanBuilder::from(table_scan_right)
.project(vec![col("a"), col("c")])?
.build()?;
+
let plan = LogicalPlanBuilder::from(left)
.join(
&right,
@@ -983,20 +1105,20 @@ mod tests {
format!("{:?}", plan),
"\
Filter: #test.b LtEq Int64(1)\
- \n Join: #test.a = #test.a\
+ \n Join: #test.a = #test2.a\
\n Projection: #test.a, #test.b\
\n TableScan: test projection=None\
- \n Projection: #test.a, #test.c\
- \n TableScan: test projection=None"
+ \n Projection: #test2.a, #test2.c\
+ \n TableScan: test2 projection=None"
);
let expected = "\
- Join: #test.a = #test.a\
+ Join: #test.a = #test2.a\
\n Projection: #test.a, #test.b\
\n Filter: #test.b LtEq Int64(1)\
\n TableScan: test projection=None\
- \n Projection: #test.a, #test.c\
- \n TableScan: test projection=None";
+ \n Projection: #test2.a, #test2.c\
+ \n TableScan: test2 projection=None";
assert_optimized_plan_eq(&plan, expected);
Ok(())
}
diff --git a/datafusion/src/optimizer/projection_push_down.rs b/datafusion/src/optimizer/projection_push_down.rs
index 3c8f1ee..0272b9f 100644
--- a/datafusion/src/optimizer/projection_push_down.rs
+++ b/datafusion/src/optimizer/projection_push_down.rs
@@ -216,9 +216,7 @@ fn optimize_plan(
let schema = build_join_schema(
optimized_left.schema(),
optimized_right.schema(),
- on,
join_type,
- join_constraint,
)?;
Ok(LogicalPlan::Join {
@@ -499,7 +497,7 @@ mod tests {
}
#[test]
- fn join_schema_trim() -> Result<()> {
+ fn join_schema_trim_full_join_column_projection() -> Result<()> {
let table_scan = test_table_scan()?;
let schema = Schema::new(vec![Field::new("c1", DataType::UInt32,
false)]);
@@ -511,7 +509,7 @@ mod tests {
.project(vec![col("a"), col("b"), col("c1")])?
.build()?;
- // make sure projections are pushed down to table scan
+ // make sure projections are pushed down to both table scans
let expected = "Projection: #test.a, #test.b, #test2.c1\
\n Join: #test.a = #test2.c1\
\n TableScan: test projection=Some([0, 1])\
@@ -521,7 +519,48 @@ mod tests {
let formatted_plan = format!("{:?}", optimized_plan);
assert_eq!(formatted_plan, expected);
- // make sure schema for join node doesn't include c1 column
+ // make sure schema for join node include both join columns
+ let optimized_join = optimized_plan.inputs()[0];
+ assert_eq!(
+ **optimized_join.schema(),
+ DFSchema::new(vec![
+ DFField::new(Some("test"), "a", DataType::UInt32, false),
+ DFField::new(Some("test"), "b", DataType::UInt32, false),
+ DFField::new(Some("test2"), "c1", DataType::UInt32, false),
+ ])?,
+ );
+
+ Ok(())
+ }
+
+ #[test]
+ fn join_schema_trim_partial_join_column_projection() -> Result<()> {
+ // test join column push down without explicit column projections
+
+ let table_scan = test_table_scan()?;
+
+ let schema = Schema::new(vec![Field::new("c1", DataType::UInt32,
false)]);
+ let table2_scan =
+            LogicalPlanBuilder::scan_empty(Some("test2"), &schema, None)?.build()?;
+
+ let plan = LogicalPlanBuilder::from(table_scan)
+ .join(&table2_scan, JoinType::Left, vec!["a"], vec!["c1"])?
+            // projecting joined column `a` should push the right side column `c1` projection as
+            // well into test2 table even though `c1` is not referenced in projection.
+ .project(vec![col("a"), col("b")])?
+ .build()?;
+
+ // make sure projections are pushed down to both table scans
+ let expected = "Projection: #test.a, #test.b\
+ \n Join: #test.a = #test2.c1\
+ \n TableScan: test projection=Some([0, 1])\
+ \n TableScan: test2 projection=Some([0])";
+
+ let optimized_plan = optimize(&plan)?;
+ let formatted_plan = format!("{:?}", optimized_plan);
+ assert_eq!(formatted_plan, expected);
+
+ // make sure schema for join node include both join columns
let optimized_join = optimized_plan.inputs()[0];
assert_eq!(
**optimized_join.schema(),
@@ -536,6 +575,45 @@ mod tests {
}
#[test]
+ fn join_schema_trim_using_join() -> Result<()> {
+        // shared join columns from using join should be pushed to both sides
+
+ let table_scan = test_table_scan()?;
+
+ let schema = Schema::new(vec![Field::new("a", DataType::UInt32,
false)]);
+ let table2_scan =
+            LogicalPlanBuilder::scan_empty(Some("test2"), &schema, None)?.build()?;
+
+ let plan = LogicalPlanBuilder::from(table_scan)
+ .join_using(&table2_scan, JoinType::Left, vec!["a"])?
+ .project(vec![col("a"), col("b")])?
+ .build()?;
+
+ // make sure projections are pushed down to table scan
+ let expected = "Projection: #test.a, #test.b\
+ \n Join: Using #test.a = #test2.a\
+ \n TableScan: test projection=Some([0, 1])\
+ \n TableScan: test2 projection=Some([0])";
+
+ let optimized_plan = optimize(&plan)?;
+ let formatted_plan = format!("{:?}", optimized_plan);
+ assert_eq!(formatted_plan, expected);
+
+ // make sure schema for join node include both join columns
+ let optimized_join = optimized_plan.inputs()[0];
+ assert_eq!(
+ **optimized_join.schema(),
+ DFSchema::new(vec![
+ DFField::new(Some("test"), "a", DataType::UInt32, false),
+ DFField::new(Some("test"), "b", DataType::UInt32, false),
+ DFField::new(Some("test2"), "a", DataType::UInt32, false),
+ ])?,
+ );
+
+ Ok(())
+ }
+
+ #[test]
fn cast() -> Result<()> {
let table_scan = test_table_scan()?;
diff --git a/datafusion/src/optimizer/utils.rs b/datafusion/src/optimizer/utils.rs
index ae3e196..1d19f06 100644
--- a/datafusion/src/optimizer/utils.rs
+++ b/datafusion/src/optimizer/utils.rs
@@ -215,13 +215,8 @@ pub fn from_plan(
on,
..
} => {
- let schema = build_join_schema(
- inputs[0].schema(),
- inputs[1].schema(),
- on,
- join_type,
- join_constraint,
- )?;
+ let schema =
+                build_join_schema(inputs[0].schema(), inputs[1].schema(), join_type)?;
Ok(LogicalPlan::Join {
left: Arc::new(inputs[0].clone()),
right: Arc::new(inputs[1].clone()),
diff --git a/datafusion/src/physical_plan/hash_join.rs b/datafusion/src/physical_plan/hash_join.rs
index f426bc9..00ca153 100644
--- a/datafusion/src/physical_plan/hash_join.rs
+++ b/datafusion/src/physical_plan/hash_join.rs
@@ -55,9 +55,10 @@ use arrow::array::{
use super::expressions::Column;
use super::{
coalesce_partitions::CoalescePartitionsExec,
- hash_utils::{build_join_schema, check_join_is_valid, JoinOn, JoinType},
+ hash_utils::{build_join_schema, check_join_is_valid, JoinOn},
};
use crate::error::{DataFusionError, Result};
+use crate::logical_plan::JoinType;
use super::{
DisplayFormatType, ExecutionPlan, Partitioning, RecordBatchStream,
@@ -165,12 +166,7 @@ impl HashJoinExec {
let right_schema = right.schema();
check_join_is_valid(&left_schema, &right_schema, &on)?;
- let schema = Arc::new(build_join_schema(
- &left_schema,
- &right_schema,
- &on,
- join_type,
- ));
+        let schema = Arc::new(build_join_schema(&left_schema, &right_schema, join_type));
let random_state = RandomState::with_seeds(0, 0, 0, 0);
@@ -1437,16 +1433,16 @@ mod tests {
join_collect(left.clone(), right.clone(), on.clone(),
&JoinType::Inner)
.await?;
- assert_eq!(columns, vec!["a1", "b1", "c1", "a2", "c2"]);
+ assert_eq!(columns, vec!["a1", "b1", "c1", "a2", "b1", "c2"]);
let expected = vec![
- "+----+----+----+----+----+",
- "| a1 | b1 | c1 | a2 | c2 |",
- "+----+----+----+----+----+",
- "| 1 | 4 | 7 | 10 | 70 |",
- "| 2 | 5 | 8 | 20 | 80 |",
- "| 3 | 5 | 9 | 20 | 80 |",
- "+----+----+----+----+----+",
+ "+----+----+----+----+----+----+",
+ "| a1 | b1 | c1 | a2 | b1 | c2 |",
+ "+----+----+----+----+----+----+",
+ "| 1 | 4 | 7 | 10 | 4 | 70 |",
+ "| 2 | 5 | 8 | 20 | 5 | 80 |",
+ "| 3 | 5 | 9 | 20 | 5 | 80 |",
+ "+----+----+----+----+----+----+",
];
assert_batches_sorted_eq!(expected, &batches);
@@ -1478,16 +1474,16 @@ mod tests {
)
.await?;
- assert_eq!(columns, vec!["a1", "b1", "c1", "a2", "c2"]);
+ assert_eq!(columns, vec!["a1", "b1", "c1", "a2", "b1", "c2"]);
let expected = vec![
- "+----+----+----+----+----+",
- "| a1 | b1 | c1 | a2 | c2 |",
- "+----+----+----+----+----+",
- "| 1 | 4 | 7 | 10 | 70 |",
- "| 2 | 5 | 8 | 20 | 80 |",
- "| 3 | 5 | 9 | 20 | 80 |",
- "+----+----+----+----+----+",
+ "+----+----+----+----+----+----+",
+ "| a1 | b1 | c1 | a2 | b1 | c2 |",
+ "+----+----+----+----+----+----+",
+ "| 1 | 4 | 7 | 10 | 4 | 70 |",
+ "| 2 | 5 | 8 | 20 | 5 | 80 |",
+ "| 3 | 5 | 9 | 20 | 5 | 80 |",
+ "+----+----+----+----+----+----+",
];
assert_batches_sorted_eq!(expected, &batches);
@@ -1555,18 +1551,18 @@ mod tests {
let (columns, batches) = join_collect(left, right, on,
&JoinType::Inner).await?;
- assert_eq!(columns, vec!["a1", "b2", "c1", "c2"]);
+ assert_eq!(columns, vec!["a1", "b2", "c1", "a1", "b2", "c2"]);
assert_eq!(batches.len(), 1);
let expected = vec![
- "+----+----+----+----+",
- "| a1 | b2 | c1 | c2 |",
- "+----+----+----+----+",
- "| 1 | 1 | 7 | 70 |",
- "| 2 | 2 | 8 | 80 |",
- "| 2 | 2 | 9 | 80 |",
- "+----+----+----+----+",
+ "+----+----+----+----+----+----+",
+ "| a1 | b2 | c1 | a1 | b2 | c2 |",
+ "+----+----+----+----+----+----+",
+ "| 1 | 1 | 7 | 1 | 1 | 70 |",
+ "| 2 | 2 | 8 | 2 | 2 | 80 |",
+ "| 2 | 2 | 9 | 2 | 2 | 80 |",
+ "+----+----+----+----+----+----+",
];
assert_batches_sorted_eq!(expected, &batches);
@@ -1607,18 +1603,18 @@ mod tests {
let (columns, batches) = join_collect(left, right, on, &JoinType::Inner).await?;
- assert_eq!(columns, vec!["a1", "b2", "c1", "c2"]);
+ assert_eq!(columns, vec!["a1", "b2", "c1", "a1", "b2", "c2"]);
assert_eq!(batches.len(), 1);
let expected = vec![
- "+----+----+----+----+",
- "| a1 | b2 | c1 | c2 |",
- "+----+----+----+----+",
- "| 1 | 1 | 7 | 70 |",
- "| 2 | 2 | 8 | 80 |",
- "| 2 | 2 | 9 | 80 |",
- "+----+----+----+----+",
+ "+----+----+----+----+----+----+",
+ "| a1 | b2 | c1 | a1 | b2 | c2 |",
+ "+----+----+----+----+----+----+",
+ "| 1 | 1 | 7 | 1 | 1 | 70 |",
+ "| 2 | 2 | 8 | 2 | 2 | 80 |",
+ "| 2 | 2 | 9 | 2 | 2 | 80 |",
+ "+----+----+----+----+----+----+",
];
assert_batches_sorted_eq!(expected, &batches);
@@ -1655,7 +1651,7 @@ mod tests {
let join = join(left, right, on, &JoinType::Inner)?;
let columns = columns(&join.schema());
- assert_eq!(columns, vec!["a1", "b1", "c1", "a2", "c2"]);
+ assert_eq!(columns, vec!["a1", "b1", "c1", "a2", "b1", "c2"]);
// first part
let stream = join.execute(0).await?;
@@ -1663,11 +1659,11 @@ mod tests {
assert_eq!(batches.len(), 1);
let expected = vec![
- "+----+----+----+----+----+",
- "| a1 | b1 | c1 | a2 | c2 |",
- "+----+----+----+----+----+",
- "| 1 | 4 | 7 | 10 | 70 |",
- "+----+----+----+----+----+",
+ "+----+----+----+----+----+----+",
+ "| a1 | b1 | c1 | a2 | b1 | c2 |",
+ "+----+----+----+----+----+----+",
+ "| 1 | 4 | 7 | 10 | 4 | 70 |",
+ "+----+----+----+----+----+----+",
];
assert_batches_sorted_eq!(expected, &batches);
@@ -1676,12 +1672,12 @@ mod tests {
let batches = common::collect(stream).await?;
assert_eq!(batches.len(), 1);
let expected = vec![
- "+----+----+----+----+----+",
- "| a1 | b1 | c1 | a2 | c2 |",
- "+----+----+----+----+----+",
- "| 2 | 5 | 8 | 30 | 90 |",
- "| 3 | 5 | 9 | 30 | 90 |",
- "+----+----+----+----+----+",
+ "+----+----+----+----+----+----+",
+ "| a1 | b1 | c1 | a2 | b1 | c2 |",
+ "+----+----+----+----+----+----+",
+ "| 2 | 5 | 8 | 30 | 5 | 90 |",
+ "| 3 | 5 | 9 | 30 | 5 | 90 |",
+ "+----+----+----+----+----+----+",
];
assert_batches_sorted_eq!(expected, &batches);
@@ -1721,21 +1717,21 @@ mod tests {
let join = join(left, right, on, &JoinType::Left).unwrap();
let columns = columns(&join.schema());
- assert_eq!(columns, vec!["a1", "b1", "c1", "a2", "c2"]);
+ assert_eq!(columns, vec!["a1", "b1", "c1", "a2", "b1", "c2"]);
let stream = join.execute(0).await.unwrap();
let batches = common::collect(stream).await.unwrap();
let expected = vec![
- "+----+----+----+----+----+",
- "| a1 | b1 | c1 | a2 | c2 |",
- "+----+----+----+----+----+",
- "| 1 | 4 | 7 | 10 | 70 |",
- "| 1 | 4 | 7 | 10 | 70 |",
- "| 2 | 5 | 8 | 20 | 80 |",
- "| 2 | 5 | 8 | 20 | 80 |",
- "| 3 | 7 | 9 | | |",
- "+----+----+----+----+----+",
+ "+----+----+----+----+----+----+",
+ "| a1 | b1 | c1 | a2 | b1 | c2 |",
+ "+----+----+----+----+----+----+",
+ "| 1 | 4 | 7 | 10 | 4 | 70 |",
+ "| 1 | 4 | 7 | 10 | 4 | 70 |",
+ "| 2 | 5 | 8 | 20 | 5 | 80 |",
+ "| 2 | 5 | 8 | 20 | 5 | 80 |",
+ "| 3 | 7 | 9 | | 7 | |",
+ "+----+----+----+----+----+----+",
];
assert_batches_sorted_eq!(expected, &batches);
@@ -1801,19 +1797,19 @@ mod tests {
let join = join(left, right, on, &JoinType::Left).unwrap();
let columns = columns(&join.schema());
- assert_eq!(columns, vec!["a1", "b1", "c1", "a2", "c2"]);
+ assert_eq!(columns, vec!["a1", "b1", "c1", "a2", "b1", "c2"]);
let stream = join.execute(0).await.unwrap();
let batches = common::collect(stream).await.unwrap();
let expected = vec![
- "+----+----+----+----+----+",
- "| a1 | b1 | c1 | a2 | c2 |",
- "+----+----+----+----+----+",
- "| 1 | 4 | 7 | | |",
- "| 2 | 5 | 8 | | |",
- "| 3 | 7 | 9 | | |",
- "+----+----+----+----+----+",
+ "+----+----+----+----+----+----+",
+ "| a1 | b1 | c1 | a2 | b1 | c2 |",
+ "+----+----+----+----+----+----+",
+ "| 1 | 4 | 7 | | 4 | |",
+ "| 2 | 5 | 8 | | 5 | |",
+ "| 3 | 7 | 9 | | 7 | |",
+ "+----+----+----+----+----+----+",
];
assert_batches_sorted_eq!(expected, &batches);
@@ -1874,16 +1870,16 @@ mod tests {
let (columns, batches) =
join_collect(left.clone(), right.clone(), on.clone(), &JoinType::Left)
.await?;
- assert_eq!(columns, vec!["a1", "b1", "c1", "a2", "c2"]);
+ assert_eq!(columns, vec!["a1", "b1", "c1", "a2", "b1", "c2"]);
let expected = vec![
- "+----+----+----+----+----+",
- "| a1 | b1 | c1 | a2 | c2 |",
- "+----+----+----+----+----+",
- "| 1 | 4 | 7 | 10 | 70 |",
- "| 2 | 5 | 8 | 20 | 80 |",
- "| 3 | 7 | 9 | | |",
- "+----+----+----+----+----+",
+ "+----+----+----+----+----+----+",
+ "| a1 | b1 | c1 | a2 | b1 | c2 |",
+ "+----+----+----+----+----+----+",
+ "| 1 | 4 | 7 | 10 | 4 | 70 |",
+ "| 2 | 5 | 8 | 20 | 5 | 80 |",
+ "| 3 | 7 | 9 | | 7 | |",
+ "+----+----+----+----+----+----+",
];
assert_batches_sorted_eq!(expected, &batches);
@@ -1914,16 +1910,16 @@ mod tests {
&JoinType::Left,
)
.await?;
- assert_eq!(columns, vec!["a1", "b1", "c1", "a2", "c2"]);
+ assert_eq!(columns, vec!["a1", "b1", "c1", "a2", "b1", "c2"]);
let expected = vec![
- "+----+----+----+----+----+",
- "| a1 | b1 | c1 | a2 | c2 |",
- "+----+----+----+----+----+",
- "| 1 | 4 | 7 | 10 | 70 |",
- "| 2 | 5 | 8 | 20 | 80 |",
- "| 3 | 7 | 9 | | |",
- "+----+----+----+----+----+",
+ "+----+----+----+----+----+----+",
+ "| a1 | b1 | c1 | a2 | b1 | c2 |",
+ "+----+----+----+----+----+----+",
+ "| 1 | 4 | 7 | 10 | 4 | 70 |",
+ "| 2 | 5 | 8 | 20 | 5 | 80 |",
+ "| 3 | 7 | 9 | | 7 | |",
+ "+----+----+----+----+----+----+",
];
assert_batches_sorted_eq!(expected, &batches);
@@ -2025,16 +2021,16 @@ mod tests {
let (columns, batches) = join_collect(left, right, on, &JoinType::Right).await?;
- assert_eq!(columns, vec!["a1", "c1", "a2", "b1", "c2"]);
+ assert_eq!(columns, vec!["a1", "b1", "c1", "a2", "b1", "c2"]);
let expected = vec![
- "+----+----+----+----+----+",
- "| a1 | c1 | a2 | b1 | c2 |",
- "+----+----+----+----+----+",
- "| | | 30 | 6 | 90 |",
- "| 1 | 7 | 10 | 4 | 70 |",
- "| 2 | 8 | 20 | 5 | 80 |",
- "+----+----+----+----+----+",
+ "+----+----+----+----+----+----+",
+ "| a1 | b1 | c1 | a2 | b1 | c2 |",
+ "+----+----+----+----+----+----+",
+ "| | 6 | | 30 | 6 | 90 |",
+ "| 1 | 4 | 7 | 10 | 4 | 70 |",
+ "| 2 | 5 | 8 | 20 | 5 | 80 |",
+ "+----+----+----+----+----+----+",
];
assert_batches_sorted_eq!(expected, &batches);
@@ -2062,16 +2058,16 @@ mod tests {
let (columns, batches) =
partitioned_join_collect(left, right, on, &JoinType::Right).await?;
- assert_eq!(columns, vec!["a1", "c1", "a2", "b1", "c2"]);
+ assert_eq!(columns, vec!["a1", "b1", "c1", "a2", "b1", "c2"]);
let expected = vec![
- "+----+----+----+----+----+",
- "| a1 | c1 | a2 | b1 | c2 |",
- "+----+----+----+----+----+",
- "| | | 30 | 6 | 90 |",
- "| 1 | 7 | 10 | 4 | 70 |",
- "| 2 | 8 | 20 | 5 | 80 |",
- "+----+----+----+----+----+",
+ "+----+----+----+----+----+----+",
+ "| a1 | b1 | c1 | a2 | b1 | c2 |",
+ "+----+----+----+----+----+----+",
+ "| | 6 | | 30 | 6 | 90 |",
+ "| 1 | 4 | 7 | 10 | 4 | 70 |",
+ "| 2 | 5 | 8 | 20 | 5 | 80 |",
+ "+----+----+----+----+----+----+",
];
assert_batches_sorted_eq!(expected, &batches);
diff --git a/datafusion/src/physical_plan/hash_utils.rs b/datafusion/src/physical_plan/hash_utils.rs
index 0cf0b92..9243aff 100644
--- a/datafusion/src/physical_plan/hash_utils.rs
+++ b/datafusion/src/physical_plan/hash_utils.rs
@@ -21,25 +21,9 @@ use crate::error::{DataFusionError, Result};
use arrow::datatypes::{Field, Schema};
use std::collections::HashSet;
+use crate::logical_plan::JoinType;
use crate::physical_plan::expressions::Column;
-/// All valid types of joins.
-#[derive(Clone, Copy, Debug, Eq, PartialEq)]
-pub enum JoinType {
- /// Inner Join
- Inner,
- /// Left Join
- Left,
- /// Right Join
- Right,
- /// Full Join
- Full,
- /// Semi Join
- Semi,
- /// Anti Join
- Anti,
-}
-
/// The on clause of the join, as vector of (left, right) columns.
pub type JoinOn = Vec<(Column, Column)>;
/// Reference for JoinOn.
@@ -104,46 +88,11 @@ fn check_join_set_is_valid(
/// Creates a schema for a join operation.
/// The fields from the left side are first
-pub fn build_join_schema(
- left: &Schema,
- right: &Schema,
- on: JoinOnRef,
- join_type: &JoinType,
-) -> Schema {
+pub fn build_join_schema(left: &Schema, right: &Schema, join_type: &JoinType) -> Schema {
let fields: Vec<Field> = match join_type {
- JoinType::Inner | JoinType::Left | JoinType::Full => {
- // remove right-side join keys if they have the same names as the left-side
- let duplicate_keys = &on
- .iter()
- .filter(|(l, r)| l.name() == r.name())
- .map(|on| on.1.name())
- .collect::<HashSet<_>>();
-
+ JoinType::Inner | JoinType::Left | JoinType::Full | JoinType::Right => {
let left_fields = left.fields().iter();
-
- let right_fields = right
- .fields()
- .iter()
- .filter(|f| !duplicate_keys.contains(f.name().as_str()));
-
- // left then right
- left_fields.chain(right_fields).cloned().collect()
- }
- JoinType::Right => {
- // remove left-side join keys if they have the same names as the right-side
- let duplicate_keys = &on
- .iter()
- .filter(|(l, r)| l.name() == r.name())
- .map(|on| on.1.name())
- .collect::<HashSet<_>>();
-
- let left_fields = left
- .fields()
- .iter()
- .filter(|f| !duplicate_keys.contains(f.name().as_str()));
-
let right_fields = right.fields().iter();
-
// left then right
left_fields.chain(right_fields).cloned().collect()
}
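
Editorial note: the rewritten `build_join_schema` above no longer receives the `on` keys, so for Inner/Left/Full/Right joins it keeps every field from both inputs instead of dropping same-named keys from one side. Below is a minimal, standalone sketch of that concatenation using plain Arrow schemas rather than calling the DataFusion function itself; the field names are illustrative only.

    use arrow::datatypes::{DataType, Field, Schema};

    // Sketch of the new behaviour: the joined schema is simply the left fields
    // followed by the right fields, so a key present on both sides appears twice.
    fn concat_join_schema(left: &Schema, right: &Schema) -> Schema {
        let fields: Vec<Field> = left
            .fields()
            .iter()
            .chain(right.fields().iter())
            .cloned()
            .collect();
        Schema::new(fields)
    }

    fn main() {
        let left = Schema::new(vec![
            Field::new("a1", DataType::UInt32, false),
            Field::new("b1", DataType::UInt32, false),
        ]);
        let right = Schema::new(vec![
            Field::new("a2", DataType::UInt32, false),
            Field::new("b1", DataType::UInt32, false),
        ]);
        let joined = concat_join_schema(&left, &right);
        // Prints ["a1", "b1", "a2", "b1"]; both "b1" columns are kept.
        println!("{:?}", joined.fields().iter().map(|f| f.name()).collect::<Vec<_>>());
    }
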
diff --git a/datafusion/src/physical_plan/planner.rs b/datafusion/src/physical_plan/planner.rs
index effdefc..73b2f36 100644
--- a/datafusion/src/physical_plan/planner.rs
+++ b/datafusion/src/physical_plan/planner.rs
@@ -40,7 +40,6 @@ use crate::physical_plan::udf;
use crate::physical_plan::windows::WindowAggExec;
use crate::physical_plan::{hash_utils, Partitioning};
use crate::physical_plan::{AggregateExpr, ExecutionPlan, PhysicalExpr, WindowExpr};
-use crate::prelude::JoinType;
use crate::scalar::ScalarValue;
use crate::sql::utils::{generate_sort_key, window_expr_common_partition_keys};
use crate::variable::VarType;
@@ -661,14 +660,6 @@ impl DefaultPhysicalPlanner {
let physical_left = self.create_initial_plan(left, ctx_state)?;
let right_df_schema = right.schema();
let physical_right = self.create_initial_plan(right, ctx_state)?;
- let physical_join_type = match join_type {
- JoinType::Inner => hash_utils::JoinType::Inner,
- JoinType::Left => hash_utils::JoinType::Left,
- JoinType::Right => hash_utils::JoinType::Right,
- JoinType::Full => hash_utils::JoinType::Full,
- JoinType::Semi => hash_utils::JoinType::Semi,
- JoinType::Anti => hash_utils::JoinType::Anti,
- };
let join_on = keys
.iter()
.map(|(l, r)| {
@@ -702,7 +693,7 @@ impl DefaultPhysicalPlanner {
Partitioning::Hash(right_expr, ctx_state.config.concurrency),
)?),
join_on,
- &physical_join_type,
+ join_type,
PartitionMode::Partitioned,
)?))
} else {
@@ -710,7 +701,7 @@ impl DefaultPhysicalPlanner {
physical_left,
physical_right,
join_on,
- &physical_join_type,
+ join_type,
PartitionMode::CollectLeft,
)?))
}
diff --git a/datafusion/src/sql/planner.rs b/datafusion/src/sql/planner.rs
index e34f0e6..f89ba3f 100644
--- a/datafusion/src/sql/planner.rs
+++ b/datafusion/src/sql/planner.rs
@@ -27,8 +27,8 @@ use crate::datasource::TableProvider;
use crate::logical_plan::window_frames::{WindowFrame, WindowFrameUnits};
use crate::logical_plan::Expr::Alias;
use crate::logical_plan::{
- and, lit, union_with_alias, Column, DFSchema, Expr, LogicalPlan, LogicalPlanBuilder,
- Operator, PlanType, StringifiedPlan, ToDFSchema,
+ and, col, lit, normalize_col, union_with_alias, Column, DFSchema, Expr, LogicalPlan,
+ LogicalPlanBuilder, Operator, PlanType, StringifiedPlan, ToDFSchema,
};
use crate::prelude::JoinType;
use crate::scalar::ScalarValue;
@@ -496,12 +496,12 @@ impl<'a, S: ContextProvider> SqlToRel<'a, S> {
let right_schema = right.schema();
let mut join_keys = vec![];
for (l, r) in &possible_join_keys {
- if left_schema.field_from_qualified_column(l).is_ok()
- && right_schema.field_from_qualified_column(r).is_ok()
+ if left_schema.field_from_column(l).is_ok()
+ && right_schema.field_from_column(r).is_ok()
{
join_keys.push((l.clone(), r.clone()));
- } else if left_schema.field_from_qualified_column(r).is_ok()
- && right_schema.field_from_qualified_column(l).is_ok()
+ } else if left_schema.field_from_column(r).is_ok()
+ && right_schema.field_from_column(l).is_ok()
{
join_keys.push((r.clone(), l.clone()));
}
@@ -579,7 +579,8 @@ impl<'a, S: ContextProvider> SqlToRel<'a, S> {
// SELECT c1 AS m FROM t HAVING c1 > 10;
// SELECT c1, MAX(c2) AS m FROM t GROUP BY c1 HAVING MAX(c2) > 10;
//
- resolve_aliases_to_exprs(&having_expr, &alias_map)
+ let having_expr = resolve_aliases_to_exprs(&having_expr, &alias_map)?;
+ normalize_col(having_expr, &projected_plan)
})
.transpose()?;
@@ -603,6 +604,7 @@ impl<'a, S: ContextProvider> SqlToRel<'a, S> {
let group_by_expr =
resolve_positions_to_exprs(&group_by_expr, &select_exprs)
.unwrap_or(group_by_expr);
+ let group_by_expr = normalize_col(group_by_expr, &projected_plan)?;
self.validate_schema_satisfies_exprs(
plan.schema(),
&[group_by_expr.clone()],
@@ -681,13 +683,14 @@ impl<'a, S: ContextProvider> SqlToRel<'a, S> {
) -> Result<Vec<Expr>> {
let input_schema = plan.schema();
- Ok(projection
+ projection
.iter()
.map(|expr| self.sql_select_to_rex(expr, input_schema))
.collect::<Result<Vec<Expr>>>()?
.iter()
.flat_map(|expr| expand_wildcard(expr, input_schema))
- .collect::<Vec<Expr>>())
+ .map(|expr| normalize_col(expr, plan))
+ .collect::<Result<Vec<Expr>>>()
}
/// Wrap a plan in a projection
@@ -835,20 +838,29 @@ impl<'a, S: ContextProvider> SqlToRel<'a, S> {
find_column_exprs(exprs)
.iter()
.try_for_each(|col| match col {
- Expr::Column(col) => {
- match &col.relation {
- Some(r) => schema.field_with_qualified_name(r, &col.name),
- None => schema.field_with_unqualified_name(&col.name),
+ Expr::Column(col) => match &col.relation {
+ Some(r) => {
+ schema.field_with_qualified_name(r, &col.name)?;
+ Ok(())
+ }
+ None => {
+ if !schema.fields_with_unqualified_name(&col.name).is_empty() {
+ Ok(())
+ } else {
+ Err(DataFusionError::Plan(format!(
+ "No field with unqualified name '{}'",
+ &col.name
+ )))
+ }
}
- .map_err(|_| {
- DataFusionError::Plan(format!(
- "Invalid identifier '{}' for schema {}",
- col,
- schema.to_string()
- ))
- })?;
- Ok(())
}
+ .map_err(|_: DataFusionError| {
+ DataFusionError::Plan(format!(
+ "Invalid identifier '{}' for schema {}",
+ col,
+ schema.to_string()
+ ))
+ }),
_ => Err(DataFusionError::Internal("Not a
column".to_string())),
})
}
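
Editorial note: the relaxed unqualified-name check above only requires that at least one matching field exists, because a join schema can now contain several fields sharing an unqualified name under different qualifiers. A hedged sketch of that situation follows; it assumes `DFSchema::new` accepts same-named fields whose qualifiers differ and that `fields_with_unqualified_name` is publicly callable and returns all matches, which is what the validator above relies on.

    use arrow::datatypes::DataType;
    use datafusion::error::Result;
    use datafusion::logical_plan::{DFField, DFSchema};

    fn main() -> Result<()> {
        // Two "id" fields under different qualifiers, as a join of `person`
        // and `person2` now produces.
        let schema = DFSchema::new(vec![
            DFField::new(Some("person"), "id", DataType::Int32, false),
            DFField::new(Some("person2"), "id", DataType::Int32, false),
        ])?;
        // Both fields match the bare name "id", so an unqualified reference
        // passes the check above.
        assert_eq!(schema.fields_with_unqualified_name("id").len(), 2);
        Ok(())
    }
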
@@ -926,11 +938,9 @@ impl<'a, S: ContextProvider> SqlToRel<'a, S> {
let var_names = vec![id.value.clone()];
Ok(Expr::ScalarVariable(var_names))
} else {
- Ok(Expr::Column(
- schema
- .field_with_unqualified_name(&id.value)?
- .qualified_column(),
- ))
+ // create a column expression based on raw user input, this column will be
+ // normalized with qualifier later by the SQL planner.
+ Ok(col(&id.value))
}
}
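
Editorial note: as the new comment explains, identifiers now become plain `col(...)` expressions first and are qualified afterwards. Below is a hedged end-to-end sketch of that two-step pattern; it assumes `col` and `normalize_col` are re-exported from `datafusion::logical_plan` (mirroring the crate-internal import earlier in this diff), and the `person` table name and schema are illustrative.

    use arrow::datatypes::{DataType, Field, Schema};
    use datafusion::error::Result;
    use datafusion::logical_plan::{col, normalize_col, LogicalPlanBuilder};

    fn main() -> Result<()> {
        // A small empty scan of a table named "person".
        let schema = Schema::new(vec![Field::new("id", DataType::Int32, false)]);
        let plan = LogicalPlanBuilder::scan_empty(Some("person"), &schema, None)?.build()?;

        // Step 1: the raw, unqualified column the SQL planner now builds for `id`.
        let raw = col("id");
        // Step 2: normalize it against the plan, which attaches the table qualifier.
        let normalized = normalize_col(raw, &plan)?;
        println!("{:?}", normalized); // expected to display as #person.id
        Ok(())
    }
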
@@ -1672,7 +1682,7 @@ mod tests {
let err = logical_plan(sql).expect_err("query should have failed");
assert!(matches!(
err,
- DataFusionError::Plan(msg) if msg.contains("No field with
unqualified name 'doesnotexist'"),
+ DataFusionError::Plan(msg) if msg.contains("Invalid identifier
'#doesnotexist' for schema "),
));
}
@@ -1730,7 +1740,7 @@ mod tests {
let err = logical_plan(sql).expect_err("query should have failed");
assert!(matches!(
err,
- DataFusionError::Plan(msg) if msg.contains("No field with
unqualified name 'doesnotexist'"),
+ DataFusionError::Plan(msg) if msg.contains("Invalid identifier
'#doesnotexist' for schema "),
));
}
@@ -1740,7 +1750,7 @@ mod tests {
let err = logical_plan(sql).expect_err("query should have failed");
assert!(matches!(
err,
- DataFusionError::Plan(msg) if msg.contains("No field with
unqualified name 'x'"),
+ DataFusionError::Plan(msg) if msg.contains("Invalid identifier
'#x' for schema "),
));
}
@@ -2211,7 +2221,7 @@ mod tests {
let err = logical_plan(sql).expect_err("query should have failed");
assert!(matches!(
err,
- DataFusionError::Plan(msg) if msg.contains("No field with
unqualified name 'doesnotexist'"),
+ DataFusionError::Plan(msg) if msg.contains("Invalid identifier
'#doesnotexist' for schema "),
));
}
@@ -2301,7 +2311,7 @@ mod tests {
let err = logical_plan(sql).expect_err("query should have failed");
assert!(matches!(
err,
- DataFusionError::Plan(msg) if msg.contains("No field with
unqualified name 'doesnotexist'"),
+ DataFusionError::Plan(msg) if msg.contains("Column #doesnotexist
not found in provided schemas"),
));
}
@@ -2311,7 +2321,7 @@ mod tests {
let err = logical_plan(sql).expect_err("query should have failed");
assert!(matches!(
err,
- DataFusionError::Plan(msg) if msg.contains("No field with
unqualified name 'doesnotexist'"),
+ DataFusionError::Plan(msg) if msg.contains("Invalid identifier
'#doesnotexist' for schema "),
));
}
@@ -2757,7 +2767,7 @@ mod tests {
JOIN person as person2 \
USING (id)";
let expected = "Projection: #person.first_name, #person.id\
- \n Join: #person.id = #person2.id\
+ \n Join: Using #person.id = #person2.id\
\n TableScan: person projection=None\
\n TableScan: person2 projection=None";
quick_test(sql, expected);
diff --git a/datafusion/src/test/mod.rs b/datafusion/src/test/mod.rs
index df3aec4..b791551 100644
--- a/datafusion/src/test/mod.rs
+++ b/datafusion/src/test/mod.rs
@@ -110,14 +110,19 @@ pub fn aggr_test_schema() -> SchemaRef {
]))
}
-/// some tests share a common table
-pub fn test_table_scan() -> Result<LogicalPlan> {
+/// some tests share a common table with different names
+pub fn test_table_scan_with_name(name: &str) -> Result<LogicalPlan> {
let schema = Schema::new(vec![
Field::new("a", DataType::UInt32, false),
Field::new("b", DataType::UInt32, false),
Field::new("c", DataType::UInt32, false),
]);
- LogicalPlanBuilder::scan_empty(Some("test"), &schema, None)?.build()
+ LogicalPlanBuilder::scan_empty(Some(name), &schema, None)?.build()
+}
+
+/// some tests share a common table
+pub fn test_table_scan() -> Result<LogicalPlan> {
+ test_table_scan_with_name("test")
}
pub fn assert_fields_eq(plan: &LogicalPlan, expected: Vec<&str>) {
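
Editorial note: finally, a hedged sketch of how the new helper can be used inside DataFusion's test module to build two identically shaped scans under different names (the names `t1` and `t2` are illustrative, not taken from the commit).

    // Two scans of the same shape but with distinct qualifiers, so join tests
    // can refer to qualified columns such as #t1.a and #t2.a without ambiguity.
    fn two_test_scans() -> Result<(LogicalPlan, LogicalPlan)> {
        let t1 = test_table_scan_with_name("t1")?;
        let t2 = test_table_scan_with_name("t2")?;
        Ok((t1, t2))
    }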